@@ -506,6 +626,8 @@ var metricsTemplate = `
+
+
diff --git a/internal/metrics/collector.go b/internal/metrics/collector.go
index 79946eb..b36848a 100644
--- a/internal/metrics/collector.go
+++ b/internal/metrics/collector.go
@@ -18,12 +18,17 @@ import (
)
type Collector struct {
- startTime time.Time
- activeRequests int64
- totalRequests int64
- totalErrors int64
- totalBytes atomic.Int64
- latencySum atomic.Int64
+ startTime time.Time
+ activeRequests int64
+ totalRequests int64
+ totalErrors int64
+ totalBytes atomic.Int64
+ latencySum atomic.Int64
+ persistentStats struct {
+ totalRequests atomic.Int64
+ totalErrors atomic.Int64
+ totalBytes atomic.Int64
+ }
pathStats sync.Map
refererStats sync.Map
statusStats [6]atomic.Int64
@@ -55,6 +60,15 @@ func InitCollector(dbPath string, config *config.Config) error {
db: db,
}
+ // 加载历史数据
+ if lastMetrics, err := db.GetLastMetrics(); err == nil && lastMetrics != nil {
+ globalCollector.persistentStats.totalRequests.Store(lastMetrics.TotalRequests)
+ globalCollector.persistentStats.totalErrors.Store(lastMetrics.TotalErrors)
+ globalCollector.persistentStats.totalBytes.Store(lastMetrics.TotalBytes)
+ log.Printf("Loaded historical metrics: requests=%d, errors=%d, bytes=%d",
+ lastMetrics.TotalRequests, lastMetrics.TotalErrors, lastMetrics.TotalBytes)
+ }
+
globalCollector.cache = cache.NewCache(constants.CacheTTL)
globalCollector.monitor = monitor.NewMonitor()
@@ -95,13 +109,28 @@ func InitCollector(dbPath string, config *config.Config) error {
// 设置程序退出时的处理
utils.SetupCloseHandler(func() {
log.Println("Saving final metrics before shutdown...")
+ // 确保所有正在进行的操作完成
+ time.Sleep(time.Second)
+
stats := globalCollector.GetStats()
- if err := db.SaveFullMetrics(stats); err != nil {
+ if err := db.SaveMetrics(stats); err != nil {
log.Printf("Error saving final metrics: %v", err)
} else {
- log.Printf("Final metrics saved successfully")
+ log.Printf("Basic metrics saved successfully")
}
+
+ // 保存完整统计数据
+ if err := db.SaveFullMetrics(stats); err != nil {
+ log.Printf("Error saving full metrics: %v", err)
+ } else {
+ log.Printf("Full metrics saved successfully")
+ }
+
+ // 等待数据写入完成
+ time.Sleep(time.Second)
+
db.Close()
+ log.Println("Database closed successfully")
})
globalCollector.statsPool = sync.Pool{
@@ -220,20 +249,33 @@ func (c *Collector) GetStats() map[string]interface{} {
// 确保所有字段都被初始化
stats := make(map[string]interface{})
- // 基础指标
+ // 基础指标 - 合并当前会话和持久化的数据
+ currentRequests := atomic.LoadInt64(&c.totalRequests)
+ currentErrors := atomic.LoadInt64(&c.totalErrors)
+ currentBytes := c.totalBytes.Load()
+
+ totalRequests := currentRequests + c.persistentStats.totalRequests.Load()
+ totalErrors := currentErrors + c.persistentStats.totalErrors.Load()
+ totalBytes := currentBytes + c.persistentStats.totalBytes.Load()
+
+ // 计算每秒指标
+ uptime := time.Since(c.startTime).Seconds()
+ stats["requests_per_second"] = float64(currentRequests) / Max(uptime, 1)
+ stats["bytes_per_second"] = float64(currentBytes) / Max(uptime, 1)
+
stats["active_requests"] = atomic.LoadInt64(&c.activeRequests)
- stats["total_requests"] = atomic.LoadInt64(&c.totalRequests)
- stats["total_errors"] = atomic.LoadInt64(&c.totalErrors)
- stats["total_bytes"] = c.totalBytes.Load()
+ stats["total_requests"] = totalRequests
+ stats["total_errors"] = totalErrors
+ stats["total_bytes"] = totalBytes
// 系统指标
stats["num_goroutine"] = runtime.NumGoroutine()
stats["memory_usage"] = FormatBytes(m.Alloc)
// 延迟指标
- totalRequests := atomic.LoadInt64(&c.totalRequests)
- if totalRequests > 0 {
- stats["avg_latency"] = c.latencySum.Load() / totalRequests
+ currentTotalRequests := atomic.LoadInt64(&c.totalRequests)
+ if currentTotalRequests > 0 {
+ stats["avg_latency"] = c.latencySum.Load() / currentTotalRequests
} else {
stats["avg_latency"] = int64(0)
}
@@ -367,3 +409,35 @@ func Max(a, b float64) float64 {
func (c *Collector) GetDB() *models.MetricsDB {
return c.db
}
+
+func (c *Collector) SaveMetrics(stats map[string]interface{}) error {
+ // 更新持久化数据
+ c.persistentStats.totalRequests.Store(stats["total_requests"].(int64))
+ c.persistentStats.totalErrors.Store(stats["total_errors"].(int64))
+ c.persistentStats.totalBytes.Store(stats["total_bytes"].(int64))
+
+ // 重置当前会话计数器
+ atomic.StoreInt64(&c.totalRequests, 0)
+ atomic.StoreInt64(&c.totalErrors, 0)
+ c.totalBytes.Store(0)
+ c.latencySum.Store(0)
+
+ // 重置状态码统计
+ for i := range c.statusStats {
+ c.statusStats[i].Store(0)
+ }
+
+ // 重置路径统计
+ c.pathStats.Range(func(key, _ interface{}) bool {
+ c.pathStats.Delete(key)
+ return true
+ })
+
+ // 重置引用来源统计
+ c.refererStats.Range(func(key, _ interface{}) bool {
+ c.refererStats.Delete(key)
+ return true
+ })
+
+ return c.db.SaveMetrics(stats)
+}
diff --git a/internal/models/metrics.go b/internal/models/metrics.go
index 6563e21..da87c2a 100644
--- a/internal/models/metrics.go
+++ b/internal/models/metrics.go
@@ -2,8 +2,10 @@ package models
import (
"database/sql"
+ "fmt"
"log"
"proxy-go/internal/constants"
+ "strings"
"sync/atomic"
"time"
@@ -47,12 +49,31 @@ type MetricsDB struct {
DB *sql.DB
}
+type PerformanceMetrics struct {
+ Timestamp string `json:"timestamp"`
+ AvgResponseTime int64 `json:"avg_response_time"`
+ RequestsPerSecond float64 `json:"requests_per_second"`
+ BytesPerSecond float64 `json:"bytes_per_second"`
+}
+
func NewMetricsDB(dbPath string) (*MetricsDB, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, err
}
+ // 设置连接池参数
+ db.SetMaxOpenConns(1) // SQLite 只支持一个写连接
+ db.SetMaxIdleConns(1)
+ db.SetConnMaxLifetime(time.Hour)
+
+ // 设置数据库优化参数
+ db.Exec("PRAGMA busy_timeout = 5000") // 设置忙等待超时
+ db.Exec("PRAGMA journal_mode = WAL") // 使用 WAL 模式提高并发性能
+ db.Exec("PRAGMA synchronous = NORMAL") // 在保证安全的前提下提高性能
+ db.Exec("PRAGMA cache_size = -2000") // 使用2MB缓存
+ db.Exec("PRAGMA temp_store = MEMORY") // 临时表使用内存
+
// 创建必要的表
if err := initTables(db); err != nil {
db.Close()
@@ -194,47 +215,82 @@ func initTables(db *sql.DB) error {
// 定期清理旧数据
func cleanupRoutine(db *sql.DB) {
+ // 避免在启动时就立即清理
+ time.Sleep(5 * time.Minute)
+
ticker := time.NewTicker(constants.CleanupInterval)
for range ticker.C {
- // 开始事务
+ start := time.Now()
+ var totalDeleted int64
+
+ // 检查数据库大小
+ var dbSize int64
+ row := db.QueryRow("SELECT page_count * page_size FROM pragma_page_count, pragma_page_size")
+ row.Scan(&dbSize)
+ log.Printf("Current database size: %s", FormatBytes(uint64(dbSize)))
+
tx, err := db.Begin()
if err != nil {
log.Printf("Error starting cleanup transaction: %v", err)
continue
}
- // 删除超过保留期限的数据
+ // 优化清理性能
+ tx.Exec("PRAGMA synchronous = NORMAL")
+ tx.Exec("PRAGMA journal_mode = WAL")
+ tx.Exec("PRAGMA temp_store = MEMORY")
+ tx.Exec("PRAGMA cache_size = -2000")
+
+ // 先清理索引
+ tx.Exec("ANALYZE")
+ tx.Exec("PRAGMA optimize")
+
cutoff := time.Now().Add(-constants.DataRetention)
- _, err = tx.Exec(`DELETE FROM metrics_history WHERE timestamp < ?`, cutoff)
- if err != nil {
- tx.Rollback()
- log.Printf("Error cleaning metrics_history: %v", err)
- continue
+ tables := []string{
+ "metrics_history",
+ "status_stats",
+ "path_stats",
+ "performance_metrics",
+ "status_code_history",
+ "popular_paths_history",
+ "referer_history",
}
- _, err = tx.Exec(`DELETE FROM status_stats WHERE timestamp < ?`, cutoff)
- if err != nil {
- tx.Rollback()
- log.Printf("Error cleaning status_stats: %v", err)
- continue
+ for _, table := range tables {
+ // 使用批量删除提高性能
+ for {
+ result, err := tx.Exec(`DELETE FROM `+table+` WHERE timestamp < ? LIMIT 1000`, cutoff)
+ if err != nil {
+ tx.Rollback()
+ log.Printf("Error cleaning %s: %v", table, err)
+ break
+ }
+ rows, _ := result.RowsAffected()
+ totalDeleted += rows
+ if rows < 1000 {
+ break
+ }
+ }
}
- _, err = tx.Exec(`DELETE FROM path_stats WHERE timestamp < ?`, cutoff)
- if err != nil {
- tx.Rollback()
- log.Printf("Error cleaning path_stats: %v", err)
- continue
- }
-
- // 提交事务
if err := tx.Commit(); err != nil {
log.Printf("Error committing cleanup transaction: %v", err)
} else {
- log.Printf("Successfully cleaned up old metrics data")
+ log.Printf("Cleaned up %d old records in %v, freed %s",
+ totalDeleted, time.Since(start),
+ FormatBytes(uint64(dbSize-getDBSize(db))))
}
}
}
+// 获取数据库大小
+func getDBSize(db *sql.DB) int64 {
+ var size int64
+ row := db.QueryRow("SELECT page_count * page_size FROM pragma_page_count, pragma_page_size")
+ row.Scan(&size)
+ return size
+}
+
func (db *MetricsDB) SaveMetrics(stats map[string]interface{}) error {
tx, err := db.DB.Begin()
if err != nil {
@@ -291,6 +347,21 @@ func (db *MetricsDB) SaveMetrics(stats map[string]interface{}) error {
}
}
+ // 同时保存性能指标
+ _, err = tx.Exec(`
+ INSERT INTO performance_metrics (
+ avg_response_time,
+ requests_per_second,
+ bytes_per_second
+ ) VALUES (?, ?, ?)`,
+ stats["avg_latency"],
+ stats["requests_per_second"],
+ stats["bytes_per_second"],
+ )
+ if err != nil {
+ return err
+ }
+
return tx.Commit()
}
@@ -299,9 +370,18 @@ func (db *MetricsDB) Close() error {
}
func (db *MetricsDB) GetRecentMetrics(hours int) ([]HistoricalMetrics, error) {
+ // 设置查询优化参数
+ db.DB.Exec("PRAGMA temp_store = MEMORY")
+ db.DB.Exec("PRAGMA cache_size = -4000") // 使用4MB缓存
+ db.DB.Exec("PRAGMA mmap_size = 268435456") // 使用256MB内存映射
+
var interval string
if hours <= 24 {
- interval = "%Y-%m-%d %H:%M:00" // 按分钟分组
+ if hours <= 1 {
+ interval = "%Y-%m-%d %H:%M:00" // 按分钟分组
+ } else {
+ interval = "%Y-%m-%d %H:00:00" // 按小时分组
+ }
} else if hours <= 168 {
interval = "%Y-%m-%d %H:00:00" // 按小时分组
} else {
@@ -411,6 +491,26 @@ func (db *MetricsDB) SaveFullMetrics(stats map[string]interface{}) error {
}
defer tx.Rollback()
+ // 开始时记录数据库大小
+ startSize := getDBSize(db.DB)
+
+ // 优化写入性能
+ tx.Exec("PRAGMA synchronous = NORMAL")
+ tx.Exec("PRAGMA journal_mode = WAL")
+ tx.Exec("PRAGMA temp_store = MEMORY")
+ tx.Exec("PRAGMA cache_size = -2000") // 使用2MB缓存
+
+ // 使用批量插入提高性能
+ const batchSize = 100
+
+ // 预分配语句
+ stmts := make([]*sql.Stmt, 0, 4)
+ defer func() {
+ for _, stmt := range stmts {
+ stmt.Close()
+ }
+ }()
+
// 保存性能指标
_, err = tx.Exec(`
INSERT INTO performance_metrics (
@@ -426,26 +526,30 @@ func (db *MetricsDB) SaveFullMetrics(stats map[string]interface{}) error {
return err
}
+ // 使用事务提高写入性能
+ tx.Exec("PRAGMA synchronous = OFF")
+ tx.Exec("PRAGMA journal_mode = MEMORY")
+
// 保存状态码统计
statusStats := stats["status_code_stats"].(map[string]int64)
- stmt, err := tx.Prepare(`
- INSERT INTO status_code_history (status_group, count)
- VALUES (?, ?)
- `)
- if err != nil {
- return err
- }
- defer stmt.Close()
+ values := make([]string, 0, len(statusStats))
+ args := make([]interface{}, 0, len(statusStats)*2)
for group, count := range statusStats {
- if _, err := stmt.Exec(group, count); err != nil {
- return err
- }
+ values = append(values, "(?, ?)")
+ args = append(args, group, count)
+ }
+
+ query := "INSERT INTO status_code_history (status_group, count) VALUES " +
+ strings.Join(values, ",")
+
+ if _, err := tx.Exec(query, args...); err != nil {
+ return err
}
// 保存热门路径
pathStats := stats["top_paths"].([]PathMetrics)
- stmt, err = tx.Prepare(`
+ pathStmt, err := tx.Prepare(`
INSERT INTO popular_paths_history (
path, request_count, error_count, avg_latency, bytes_transferred
) VALUES (?, ?, ?, ?, ?)
@@ -453,9 +557,10 @@ func (db *MetricsDB) SaveFullMetrics(stats map[string]interface{}) error {
if err != nil {
return err
}
+ defer pathStmt.Close()
for _, p := range pathStats {
- if _, err := stmt.Exec(
+ if _, err := pathStmt.Exec(
p.Path, p.RequestCount, p.ErrorCount,
p.AvgLatency, p.BytesTransferred,
); err != nil {
@@ -465,19 +570,109 @@ func (db *MetricsDB) SaveFullMetrics(stats map[string]interface{}) error {
// 保存引用来源
refererStats := stats["top_referers"].([]PathMetrics)
- stmt, err = tx.Prepare(`
+ refererStmt, err := tx.Prepare(`
INSERT INTO referer_history (referer, request_count)
VALUES (?, ?)
`)
if err != nil {
return err
}
+ defer refererStmt.Close()
for _, r := range refererStats {
- if _, err := stmt.Exec(r.Path, r.RequestCount); err != nil {
+ if _, err := refererStmt.Exec(r.Path, r.RequestCount); err != nil {
return err
}
}
- return tx.Commit()
+ if err := tx.Commit(); err != nil {
+ return err
+ }
+
+ // 记录写入的数据量
+ endSize := getDBSize(db.DB)
+ log.Printf("Saved metrics: wrote %s to database",
+ FormatBytes(uint64(endSize-startSize)))
+
+ return nil
+}
+
+func (db *MetricsDB) GetLastMetrics() (*HistoricalMetrics, error) {
+ row := db.DB.QueryRow(`
+ SELECT
+ total_requests,
+ total_errors,
+ total_bytes,
+ avg_latency
+ FROM metrics_history
+ ORDER BY timestamp DESC
+ LIMIT 1
+ `)
+
+ var metrics HistoricalMetrics
+ err := row.Scan(
+ &metrics.TotalRequests,
+ &metrics.TotalErrors,
+ &metrics.TotalBytes,
+ &metrics.AvgLatency,
+ )
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ return &metrics, nil
+}
+
+func (db *MetricsDB) GetRecentPerformanceMetrics(hours int) ([]PerformanceMetrics, error) {
+ rows, err := db.DB.Query(`
+ SELECT
+ strftime('%Y-%m-%d %H:%M:00', timestamp, 'localtime') as ts,
+ AVG(avg_response_time) as avg_response_time,
+ AVG(requests_per_second) as requests_per_second,
+ AVG(bytes_per_second) as bytes_per_second
+ FROM performance_metrics
+ WHERE timestamp >= datetime('now', '-' || ? || ' hours', 'localtime')
+ GROUP BY ts
+ ORDER BY ts DESC
+ `, hours)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var metrics []PerformanceMetrics
+ for rows.Next() {
+ var m PerformanceMetrics
+ err := rows.Scan(
+ &m.Timestamp,
+ &m.AvgResponseTime,
+ &m.RequestsPerSecond,
+ &m.BytesPerSecond,
+ )
+ if err != nil {
+ return nil, err
+ }
+ metrics = append(metrics, m)
+ }
+
+ return metrics, rows.Err()
+}
+
+// FormatBytes 格式化字节大小
+func FormatBytes(bytes uint64) string {
+ const (
+ MB = 1024 * 1024
+ KB = 1024
+ )
+
+ switch {
+ case bytes >= MB:
+ return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
+ case bytes >= KB:
+ return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
+ default:
+ return fmt.Sprintf("%d Bytes", bytes)
+ }
}
diff --git a/internal/utils/signal.go b/internal/utils/signal.go
index 61f985a..95e3a32 100644
--- a/internal/utils/signal.go
+++ b/internal/utils/signal.go
@@ -3,15 +3,30 @@ package utils
import (
"os"
"os/signal"
+ "sync"
"syscall"
)
func SetupCloseHandler(callback func()) {
- c := make(chan os.Signal, 2)
+ c := make(chan os.Signal, 1)
+ done := make(chan bool, 1)
+ var once sync.Once
+
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
- callback()
- os.Exit(0)
+ once.Do(func() {
+ callback()
+ done <- true
+ })
+ }()
+
+ go func() {
+ select {
+ case <-done:
+ os.Exit(0)
+ case <-c:
+ os.Exit(1)
+ }
}()
}