random-api-go/main.go
wood chen 96f41e374f refactor(main, logging, utils): Refactor logging and IP extraction logic
Move setupLogging to logging package and create GetRealIP utility function
2024-10-26 15:46:00 +08:00

272 lines
6.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
"random-api-go/logging"
"random-api-go/stats"
"random-api-go/utils"
"strings"
"sync"
"syscall"
"time"
)
const (
port = ":5003"
cacheDuration = 24 * time.Hour
requestTimeout = 10 * time.Second
noRepeatCount = 3 // 在这个次数内不重复选择
)
var (
csvPathsCache map[string]map[string]string
lastFetchTime time.Time
csvCache = make(map[string]*URLSelector)
mu sync.RWMutex
rng *rand.Rand
statsManager *stats.StatsManager
)
type URLSelector struct {
URLs []string
CurrentIndex int
RecentUsed map[string]int
mu sync.Mutex
}
func NewURLSelector(urls []string) *URLSelector {
return &URLSelector{
URLs: urls,
CurrentIndex: 0,
RecentUsed: make(map[string]int),
}
}
func (us *URLSelector) ShuffleURLs() {
for i := len(us.URLs) - 1; i > 0; i-- {
j := rng.Intn(i + 1)
us.URLs[i], us.URLs[j] = us.URLs[j], us.URLs[i]
}
}
func (us *URLSelector) GetRandomURL() string {
us.mu.Lock()
defer us.mu.Unlock()
if us.CurrentIndex == 0 {
us.ShuffleURLs()
}
for i := 0; i < len(us.URLs); i++ {
url := us.URLs[us.CurrentIndex]
us.CurrentIndex = (us.CurrentIndex + 1) % len(us.URLs)
if us.RecentUsed[url] < noRepeatCount {
us.RecentUsed[url]++
// 如果某个URL使用次数达到上限从RecentUsed中移除
if us.RecentUsed[url] == noRepeatCount {
delete(us.RecentUsed, url)
}
return url
}
}
// 如果所有URL都被最近使用过重置RecentUsed并返回第一个URL
us.RecentUsed = make(map[string]int)
return us.URLs[0]
}
func init() {
// 确保数据目录存在
if err := os.MkdirAll("data", 0755); err != nil {
log.Fatal("Failed to create data directory:", err)
}
}
func main() {
source := rand.NewSource(time.Now().UnixNano())
rng = rand.New(source)
logging.SetupLogging()
statsManager = stats.NewStatsManager("data/stats.json")
// 设置优雅关闭
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
log.Println("Server is shutting down...")
// 关闭统计管理器,确保统计数据被保存
statsManager.Shutdown()
log.Println("Stats manager shutdown completed")
os.Exit(0)
}()
if err := loadCSVPaths(); err != nil {
log.Fatal("Failed to load CSV paths:", err)
}
// 设置静态文件服务
fs := http.FileServer(http.Dir("./public"))
http.Handle("/", fs)
// 设置 API 路由
http.HandleFunc("/pic/", handleAPIRequest)
http.HandleFunc("/video/", handleAPIRequest)
http.HandleFunc("/stats", handleStats)
log.Printf("Server starting on %s...\n", port)
if err := http.ListenAndServe(port, nil); err != nil {
log.Fatal(err)
}
}
func loadCSVPaths() error {
jsonPath := filepath.Join("public", "url.json")
log.Printf("Attempting to read file: %s", jsonPath)
data, err := os.ReadFile(jsonPath)
if err != nil {
return fmt.Errorf("failed to read url.json: %w", err)
}
var result map[string]map[string]string
if err := json.Unmarshal(data, &result); err != nil {
return fmt.Errorf("failed to unmarshal url.json: %w", err)
}
mu.Lock()
csvPathsCache = result
lastFetchTime = time.Now()
mu.Unlock()
log.Println("CSV paths loaded from url.json")
return nil
}
func getCSVContent(path string) (*URLSelector, error) {
mu.RLock()
selector, exists := csvCache[path]
mu.RUnlock()
if exists {
return selector, nil
}
fullPath := filepath.Join("public", path)
log.Printf("尝试读取文件: %s", fullPath)
fileContent, err := os.ReadFile(fullPath)
if err != nil {
return nil, fmt.Errorf("读取 CSV 内容时出错: %w", err)
}
lines := strings.Split(string(fileContent), "\n")
uniqueURLs := make(map[string]bool)
var fileArray []string
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed != "" && !strings.HasPrefix(trimmed, "#") && !uniqueURLs[trimmed] {
fileArray = append(fileArray, trimmed)
uniqueURLs[trimmed] = true
}
}
selector = NewURLSelector(fileArray)
mu.Lock()
csvCache[path] = selector
mu.Unlock()
return selector, nil
}
func handleAPIRequest(w http.ResponseWriter, r *http.Request) {
start := time.Now()
realIP := utils.GetRealIP(r)
referer := r.Referer()
// 获取来源域名
var sourceDomain string
if referer != "" {
if parsedURL, err := url.Parse(referer); err == nil {
sourceDomain = parsedURL.Hostname()
}
}
if sourceDomain == "" {
sourceDomain = "direct"
}
if time.Since(lastFetchTime) > cacheDuration {
if err := loadCSVPaths(); err != nil {
http.Error(w, "无法加载 CSV 路径", http.StatusInternalServerError)
log.Printf("加载 CSV 路径时出错: %v", err)
return
}
}
path := strings.TrimPrefix(r.URL.Path, "/")
pathSegments := strings.Split(path, "/")
if len(pathSegments) < 2 {
http.NotFound(w, r)
return
}
prefix := pathSegments[0]
suffix := pathSegments[1]
mu.RLock()
csvPath, ok := csvPathsCache[prefix][suffix]
mu.RUnlock()
if !ok {
http.NotFound(w, r)
return
}
selector, err := getCSVContent(csvPath)
if err != nil {
http.Error(w, "无法获取 CSV 内容", http.StatusInternalServerError)
log.Printf("获取 CSV 内容时出错: %v", err)
return
}
if len(selector.URLs) == 0 {
http.Error(w, "无可用内容", http.StatusNotFound)
return
}
randomURL := selector.GetRandomURL()
// 记录统计
endpoint := fmt.Sprintf("%s/%s", prefix, suffix)
statsManager.IncrementCalls(endpoint)
duration := time.Since(start)
log.Printf("请求:%s %s来自 %s -来源:%s -持续时间:%v -重定向至:%s",
r.Method, r.URL.Path, realIP, sourceDomain, duration, randomURL)
http.Redirect(w, r, randomURL, http.StatusFound)
}
// 统计API处理函数
func handleStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
stats := statsManager.GetStats()
if err := json.NewEncoder(w).Encode(stats); err != nil {
http.Error(w, "Error encoding stats", http.StatusInternalServerError)
log.Printf("Error encoding stats: %v", err)
}
}