feat(api, stats): add graceful shutdown and periodic stats saving

This commit is contained in:
wood chen 2024-10-26 15:29:21 +08:00
parent 09e4d8ccf7
commit 39155ac1bb
2 changed files with 141 additions and 40 deletions

32
main.go
View File

@ -10,10 +10,12 @@ import (
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
"random-api-go/stats"
"strings"
"sync"
"syscall"
"time"
)
@ -30,10 +32,9 @@ var (
csvCache = make(map[string]*URLSelector)
mu sync.RWMutex
rng *rand.Rand
statsManager *stats.StatsManager
)
var statsManager *stats.StatsManager
type URLSelector struct {
URLs []string
CurrentIndex int
@ -95,7 +96,22 @@ func main() {
rng = rand.New(source)
setupLogging()
statsManager = stats.NewStatsManager("data/stats.json")
statsManager = stats.NewStatsManager("stats.json")
// 设置优雅关闭
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
log.Println("Server is shutting down...")
// 关闭统计管理器,确保统计数据被保存
statsManager.Shutdown()
log.Println("Stats manager shutdown completed")
os.Exit(0)
}()
if err := loadCSVPaths(); err != nil {
log.Fatal("Failed to load CSV paths:", err)
@ -108,10 +124,9 @@ func main() {
// 设置 API 路由
http.HandleFunc("/pic/", handleAPIRequest)
http.HandleFunc("/video/", handleAPIRequest)
// 添加统计API路由
http.HandleFunc("/stats", handleStats)
log.Printf("Listening on %s...\n", port)
log.Printf("Server starting on %s...\n", port)
if err := http.ListenAndServe(port, nil); err != nil {
log.Fatal(err)
}
@ -276,7 +291,12 @@ func handleAPIRequest(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, randomURL, http.StatusFound)
}
// 统计API处理函数
func handleStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
stats := statsManager.GetStats()
json.NewEncoder(w).Encode(stats)
if err := json.NewEncoder(w).Encode(stats); err != nil {
http.Error(w, "Error encoding stats", http.StatusInternalServerError)
log.Printf("Error encoding stats: %v", err)
}
}

View File

@ -14,21 +14,100 @@ type EndpointStats struct {
}
type StatsManager struct {
Stats map[string]*EndpointStats `json:"stats"`
mu sync.RWMutex
filepath string
Stats map[string]*EndpointStats `json:"stats"`
mu sync.RWMutex
filepath string
isDirty bool
lastSaveTime time.Time
saveInterval time.Duration
minSaveInterval time.Duration
shutdown chan struct{}
wg sync.WaitGroup // 添加 WaitGroup 用于优雅关闭
}
func NewStatsManager(filepath string) *StatsManager {
sm := &StatsManager{
Stats: make(map[string]*EndpointStats),
filepath: filepath,
Stats: make(map[string]*EndpointStats),
filepath: filepath,
saveInterval: 3 * time.Second,
minSaveInterval: 1 * time.Second,
lastSaveTime: time.Now(),
shutdown: make(chan struct{}),
}
sm.LoadStats()
sm.wg.Add(2) // 为两个goroutine添加计数
go sm.startDailyReset()
go sm.periodicSave()
return sm
}
func (sm *StatsManager) periodicSave() {
defer sm.wg.Done()
ticker := time.NewTicker(sm.saveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
sm.mu.Lock()
if sm.isDirty && time.Since(sm.lastSaveTime) >= sm.minSaveInterval {
sm.saveStatsLocked()
sm.isDirty = false
sm.lastSaveTime = time.Now()
}
sm.mu.Unlock()
case <-sm.shutdown:
sm.mu.Lock()
if sm.isDirty {
sm.saveStatsLocked()
sm.lastSaveTime = time.Now()
}
sm.mu.Unlock()
return
}
}
}
func (sm *StatsManager) startDailyReset() {
defer sm.wg.Done()
for {
now := time.Now()
next := now.Add(24 * time.Hour)
next = time.Date(next.Year(), next.Month(), next.Day(), 0, 0, 0, 0, next.Location())
duration := next.Sub(now)
select {
case <-time.After(duration):
sm.mu.Lock()
for _, stats := range sm.Stats {
stats.TodayCalls = 0
stats.LastResetDate = time.Now().Format("2006-01-02")
}
sm.isDirty = true
sm.mu.Unlock()
case <-sm.shutdown:
return
}
}
}
// 优雅关闭
func (sm *StatsManager) Shutdown() {
close(sm.shutdown) // 通知所有goroutine关闭
sm.wg.Wait() // 等待所有goroutine完成
// 最后一次保存
sm.mu.Lock()
if sm.isDirty {
sm.saveStatsLocked()
}
sm.mu.Unlock()
}
func (sm *StatsManager) IncrementCalls(endpoint string) {
sm.mu.Lock()
defer sm.mu.Unlock()
@ -41,22 +120,10 @@ func (sm *StatsManager) IncrementCalls(endpoint string) {
sm.Stats[endpoint].TotalCalls++
sm.Stats[endpoint].TodayCalls++
// 异步保存统计数据
go sm.SaveStats()
sm.isDirty = true
}
func (sm *StatsManager) GetStats() map[string]*EndpointStats {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.Stats
}
func (sm *StatsManager) SaveStats() error {
sm.mu.RLock()
defer sm.mu.RUnlock()
func (sm *StatsManager) saveStatsLocked() error {
data, err := json.MarshalIndent(sm, "", " ")
if err != nil {
return err
@ -64,6 +131,18 @@ func (sm *StatsManager) SaveStats() error {
return os.WriteFile(sm.filepath, data, 0644)
}
func (sm *StatsManager) ForceSave() error {
sm.mu.Lock()
defer sm.mu.Unlock()
err := sm.saveStatsLocked()
if err == nil {
sm.isDirty = false
sm.lastSaveTime = time.Now()
}
return err
}
func (sm *StatsManager) LoadStats() error {
data, err := os.ReadFile(sm.filepath)
if err != nil {
@ -76,22 +155,24 @@ func (sm *StatsManager) LoadStats() error {
return json.Unmarshal(data, sm)
}
func (sm *StatsManager) startDailyReset() {
for {
now := time.Now()
next := now.Add(24 * time.Hour)
next = time.Date(next.Year(), next.Month(), next.Day(), 0, 0, 0, 0, next.Location())
duration := next.Sub(now)
func (sm *StatsManager) GetStats() map[string]*EndpointStats {
sm.mu.RLock()
defer sm.mu.RUnlock()
time.Sleep(duration)
sm.mu.Lock()
for _, stats := range sm.Stats {
stats.TodayCalls = 0
stats.LastResetDate = time.Now().Format("2006-01-02")
statsCopy := make(map[string]*EndpointStats)
for k, v := range sm.Stats {
statsCopy[k] = &EndpointStats{
TotalCalls: v.TotalCalls,
TodayCalls: v.TodayCalls,
LastResetDate: v.LastResetDate,
}
sm.mu.Unlock()
sm.SaveStats()
}
return statsCopy
}
func (sm *StatsManager) LastSaveTime() time.Time {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.lastSaveTime
}