From 39155ac1bb39521889327a591b53f5185d401daa Mon Sep 17 00:00:00 2001 From: wood chen Date: Sat, 26 Oct 2024 15:29:21 +0800 Subject: [PATCH] feat(api, stats): add graceful shutdown and periodic stats saving --- main.go | 32 +++++++++-- stats/stats.go | 149 ++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 141 insertions(+), 40 deletions(-) diff --git a/main.go b/main.go index 79db0aa..8e94d5e 100644 --- a/main.go +++ b/main.go @@ -10,10 +10,12 @@ import ( "net/http" "net/url" "os" + "os/signal" "path/filepath" "random-api-go/stats" "strings" "sync" + "syscall" "time" ) @@ -30,10 +32,9 @@ var ( csvCache = make(map[string]*URLSelector) mu sync.RWMutex rng *rand.Rand + statsManager *stats.StatsManager ) -var statsManager *stats.StatsManager - type URLSelector struct { URLs []string CurrentIndex int @@ -95,7 +96,22 @@ func main() { rng = rand.New(source) setupLogging() - statsManager = stats.NewStatsManager("data/stats.json") + statsManager = stats.NewStatsManager("stats.json") + + // 设置优雅关闭 + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + log.Println("Server is shutting down...") + + // 关闭统计管理器,确保统计数据被保存 + statsManager.Shutdown() + log.Println("Stats manager shutdown completed") + + os.Exit(0) + }() if err := loadCSVPaths(); err != nil { log.Fatal("Failed to load CSV paths:", err) @@ -108,10 +124,9 @@ func main() { // 设置 API 路由 http.HandleFunc("/pic/", handleAPIRequest) http.HandleFunc("/video/", handleAPIRequest) - // 添加统计API路由 http.HandleFunc("/stats", handleStats) - log.Printf("Listening on %s...\n", port) + log.Printf("Server starting on %s...\n", port) if err := http.ListenAndServe(port, nil); err != nil { log.Fatal(err) } @@ -276,7 +291,12 @@ func handleAPIRequest(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, randomURL, http.StatusFound) } +// 统计API处理函数 func handleStats(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") stats := statsManager.GetStats() - json.NewEncoder(w).Encode(stats) + if err := json.NewEncoder(w).Encode(stats); err != nil { + http.Error(w, "Error encoding stats", http.StatusInternalServerError) + log.Printf("Error encoding stats: %v", err) + } } diff --git a/stats/stats.go b/stats/stats.go index 78e59ea..4035907 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -14,21 +14,100 @@ type EndpointStats struct { } type StatsManager struct { - Stats map[string]*EndpointStats `json:"stats"` - mu sync.RWMutex - filepath string + Stats map[string]*EndpointStats `json:"stats"` + mu sync.RWMutex + filepath string + isDirty bool + lastSaveTime time.Time + saveInterval time.Duration + minSaveInterval time.Duration + shutdown chan struct{} + wg sync.WaitGroup // 添加 WaitGroup 用于优雅关闭 } func NewStatsManager(filepath string) *StatsManager { sm := &StatsManager{ - Stats: make(map[string]*EndpointStats), - filepath: filepath, + Stats: make(map[string]*EndpointStats), + filepath: filepath, + saveInterval: 3 * time.Second, + minSaveInterval: 1 * time.Second, + lastSaveTime: time.Now(), + shutdown: make(chan struct{}), } + sm.LoadStats() + + sm.wg.Add(2) // 为两个goroutine添加计数 go sm.startDailyReset() + go sm.periodicSave() + return sm } +func (sm *StatsManager) periodicSave() { + defer sm.wg.Done() + ticker := time.NewTicker(sm.saveInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + sm.mu.Lock() + if sm.isDirty && time.Since(sm.lastSaveTime) >= sm.minSaveInterval { + sm.saveStatsLocked() + sm.isDirty = false + sm.lastSaveTime = time.Now() + } + sm.mu.Unlock() + + case <-sm.shutdown: + sm.mu.Lock() + if sm.isDirty { + sm.saveStatsLocked() + sm.lastSaveTime = time.Now() + } + sm.mu.Unlock() + return + } + } +} + +func (sm *StatsManager) startDailyReset() { + defer sm.wg.Done() + for { + now := time.Now() + next := now.Add(24 * time.Hour) + next = time.Date(next.Year(), next.Month(), next.Day(), 0, 0, 0, 0, next.Location()) + duration := next.Sub(now) + + select { + case <-time.After(duration): + sm.mu.Lock() + for _, stats := range sm.Stats { + stats.TodayCalls = 0 + stats.LastResetDate = time.Now().Format("2006-01-02") + } + sm.isDirty = true + sm.mu.Unlock() + case <-sm.shutdown: + return + } + } +} + +// 优雅关闭 +func (sm *StatsManager) Shutdown() { + close(sm.shutdown) // 通知所有goroutine关闭 + sm.wg.Wait() // 等待所有goroutine完成 + + // 最后一次保存 + sm.mu.Lock() + if sm.isDirty { + sm.saveStatsLocked() + } + sm.mu.Unlock() +} + func (sm *StatsManager) IncrementCalls(endpoint string) { sm.mu.Lock() defer sm.mu.Unlock() @@ -41,22 +120,10 @@ func (sm *StatsManager) IncrementCalls(endpoint string) { sm.Stats[endpoint].TotalCalls++ sm.Stats[endpoint].TodayCalls++ - - // 异步保存统计数据 - go sm.SaveStats() + sm.isDirty = true } -func (sm *StatsManager) GetStats() map[string]*EndpointStats { - sm.mu.RLock() - defer sm.mu.RUnlock() - - return sm.Stats -} - -func (sm *StatsManager) SaveStats() error { - sm.mu.RLock() - defer sm.mu.RUnlock() - +func (sm *StatsManager) saveStatsLocked() error { data, err := json.MarshalIndent(sm, "", " ") if err != nil { return err @@ -64,6 +131,18 @@ func (sm *StatsManager) SaveStats() error { return os.WriteFile(sm.filepath, data, 0644) } +func (sm *StatsManager) ForceSave() error { + sm.mu.Lock() + defer sm.mu.Unlock() + + err := sm.saveStatsLocked() + if err == nil { + sm.isDirty = false + sm.lastSaveTime = time.Now() + } + return err +} + func (sm *StatsManager) LoadStats() error { data, err := os.ReadFile(sm.filepath) if err != nil { @@ -76,22 +155,24 @@ func (sm *StatsManager) LoadStats() error { return json.Unmarshal(data, sm) } -func (sm *StatsManager) startDailyReset() { - for { - now := time.Now() - next := now.Add(24 * time.Hour) - next = time.Date(next.Year(), next.Month(), next.Day(), 0, 0, 0, 0, next.Location()) - duration := next.Sub(now) +func (sm *StatsManager) GetStats() map[string]*EndpointStats { + sm.mu.RLock() + defer sm.mu.RUnlock() - time.Sleep(duration) - - sm.mu.Lock() - for _, stats := range sm.Stats { - stats.TodayCalls = 0 - stats.LastResetDate = time.Now().Format("2006-01-02") + statsCopy := make(map[string]*EndpointStats) + for k, v := range sm.Stats { + statsCopy[k] = &EndpointStats{ + TotalCalls: v.TotalCalls, + TodayCalls: v.TodayCalls, + LastResetDate: v.LastResetDate, } - sm.mu.Unlock() - - sm.SaveStats() } + + return statsCopy +} + +func (sm *StatsManager) LastSaveTime() time.Time { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.lastSaveTime }