diff --git a/go.mod b/go.mod index a9631cf..88030b3 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,7 @@ go 1.23.1 require ( github.com/andybalholm/brotli v1.1.1 - golang.org/x/time v0.9.0 + golang.org/x/net v0.35.0 ) -require ( - golang.org/x/net v0.35.0 // indirect - golang.org/x/text v0.22.0 // indirect -) +require golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index a9c480b..159688d 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,3 @@ golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= -golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= -golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/internal/cache/manager.go b/internal/cache/manager.go new file mode 100644 index 0000000..84050e4 --- /dev/null +++ b/internal/cache/manager.go @@ -0,0 +1,433 @@ +package cache + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// CacheKey 用于标识缓存项的唯一键 +type CacheKey struct { + URL string + AcceptHeaders string + UserAgent string + VaryHeadersMap map[string]string // 存储 Vary 头部的值 +} + +// CacheItem 表示一个缓存项 +type CacheItem struct { + FilePath string + ContentType string + Size int64 + LastAccess time.Time + Hash string + ETag string + LastModified time.Time + CacheControl string + VaryHeaders []string + // 新增防穿透字段 + NegativeCache bool // 标记是否为空结果缓存 + AccessCount int64 // 访问计数 + CreatedAt time.Time +} + +// CacheStats 缓存统计信息 +type CacheStats struct { + TotalItems int `json:"total_items"` // 缓存项数量 + TotalSize int64 `json:"total_size"` // 总大小 + HitCount int64 `json:"hit_count"` // 命中次数 + MissCount int64 `json:"miss_count"` // 未命中次数 + HitRate float64 `json:"hit_rate"` // 命中率 + BytesSaved int64 `json:"bytes_saved"` // 节省的带宽 + Enabled bool `json:"enabled"` // 缓存开关状态 +} + +// CacheManager 缓存管理器 +type CacheManager struct { + cacheDir string + items sync.Map + maxAge time.Duration + cleanupTick time.Duration + maxCacheSize int64 + enabled atomic.Bool // 缓存开关 + hitCount atomic.Int64 // 命中计数 + missCount atomic.Int64 // 未命中计数 + bytesSaved atomic.Int64 // 节省的带宽 +} + +// NewCacheManager 创建新的缓存管理器 +func NewCacheManager(cacheDir string) (*CacheManager, error) { + if err := os.MkdirAll(cacheDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create cache directory: %v", err) + } + + cm := &CacheManager{ + cacheDir: cacheDir, + maxAge: 30 * time.Minute, + cleanupTick: 5 * time.Minute, + maxCacheSize: 10 * 1024 * 1024 * 1024, // 10GB + } + + cm.enabled.Store(true) // 默认启用缓存 + + // 启动清理协程 + go cm.cleanup() + + return cm, nil +} + +// GenerateCacheKey 生成缓存键 +func (cm *CacheManager) GenerateCacheKey(r *http.Request) CacheKey { + return CacheKey{ + URL: r.URL.String(), + AcceptHeaders: r.Header.Get("Accept"), + UserAgent: r.Header.Get("User-Agent"), + } +} + +// Get 获取缓存项 +func (cm *CacheManager) Get(key CacheKey, r *http.Request) (*CacheItem, bool, bool) { + // 如果缓存被禁用,直接返回未命中 + if !cm.enabled.Load() { + cm.missCount.Add(1) + return nil, false, false + } + + // 检查是否存在缓存项 + if value, ok := cm.items.Load(key); ok { + item := value.(*CacheItem) + + // 检查文件是否存在 + if _, err := os.Stat(item.FilePath); err != nil { + cm.items.Delete(key) + cm.missCount.Add(1) + return nil, false, false + } + + // 检查是否为负缓存(防止缓存穿透) + if item.NegativeCache { + // 如果访问次数较少且是负缓存,允许重新验证 + if item.AccessCount < 10 { + item.AccessCount++ + return nil, false, false + } + // 返回空结果,但标记为命中 + cm.hitCount.Add(1) + return nil, true, true + } + + // 检查 Vary 头部 + for _, varyHeader := range item.VaryHeaders { + if r.Header.Get(varyHeader) != key.VaryHeadersMap[varyHeader] { + cm.missCount.Add(1) + return nil, false, false + } + } + + // 处理条件请求 + ifNoneMatch := r.Header.Get("If-None-Match") + ifModifiedSince := r.Header.Get("If-Modified-Since") + + // ETag 匹配 + if ifNoneMatch != "" && item.ETag != "" { + if ifNoneMatch == item.ETag { + cm.hitCount.Add(1) + return item, true, true + } + } + + // Last-Modified 匹配 + if ifModifiedSince != "" && !item.LastModified.IsZero() { + if modifiedSince, err := time.Parse(time.RFC1123, ifModifiedSince); err == nil { + if !item.LastModified.After(modifiedSince) { + cm.hitCount.Add(1) + return item, true, true + } + } + } + + // 检查 Cache-Control + if item.CacheControl != "" { + if cm.isCacheExpired(item) { + cm.items.Delete(key) + cm.missCount.Add(1) + return nil, false, false + } + } + + // 更新访问统计 + item.LastAccess = time.Now() + item.AccessCount++ + cm.hitCount.Add(1) + cm.bytesSaved.Add(item.Size) + return item, true, false + } + + cm.missCount.Add(1) + return nil, false, false +} + +// isCacheExpired 检查缓存是否过期 +func (cm *CacheManager) isCacheExpired(item *CacheItem) bool { + if item.CacheControl == "" { + return false + } + + // 解析 max-age + if strings.Contains(item.CacheControl, "max-age=") { + parts := strings.Split(item.CacheControl, "max-age=") + if len(parts) > 1 { + maxAge := strings.Split(parts[1], ",")[0] + if seconds, err := strconv.Atoi(maxAge); err == nil { + return time.Since(item.CreatedAt) > time.Duration(seconds)*time.Second + } + } + } + + return false +} + +// Put 添加缓存项 +func (cm *CacheManager) Put(key CacheKey, resp *http.Response, body []byte) (*CacheItem, error) { + // 检查缓存控制头 + if !cm.shouldCache(resp) { + return nil, fmt.Errorf("response should not be cached") + } + + // 生成文件名 + hash := sha256.Sum256([]byte(fmt.Sprintf("%v-%v-%v-%v", key.URL, key.AcceptHeaders, key.UserAgent, time.Now().UnixNano()))) + fileName := hex.EncodeToString(hash[:]) + filePath := filepath.Join(cm.cacheDir, fileName) + + // 使用更安全的文件权限 + if err := os.WriteFile(filePath, body, 0600); err != nil { + return nil, fmt.Errorf("failed to write cache file: %v", err) + } + + // 计算内容哈希 + contentHash := sha256.Sum256(body) + + // 解析缓存控制头 + cacheControl := resp.Header.Get("Cache-Control") + lastModified := resp.Header.Get("Last-Modified") + etag := resp.Header.Get("ETag") + + var lastModifiedTime time.Time + if lastModified != "" { + if t, err := time.Parse(time.RFC1123, lastModified); err == nil { + lastModifiedTime = t + } + } + + // 处理 Vary 头部 + varyHeaders := strings.Split(resp.Header.Get("Vary"), ",") + for i, h := range varyHeaders { + varyHeaders[i] = strings.TrimSpace(h) + } + + item := &CacheItem{ + FilePath: filePath, + ContentType: resp.Header.Get("Content-Type"), + Size: int64(len(body)), + LastAccess: time.Now(), + Hash: hex.EncodeToString(contentHash[:]), + ETag: etag, + LastModified: lastModifiedTime, + CacheControl: cacheControl, + VaryHeaders: varyHeaders, + CreatedAt: time.Now(), + AccessCount: 1, + } + + // 检查是否有相同内容的缓存 + var existingItem *CacheItem + cm.items.Range(func(k, v interface{}) bool { + if i := v.(*CacheItem); i.Hash == item.Hash { + existingItem = i + return false + } + return true + }) + + if existingItem != nil { + // 如果找到相同内容的缓存,删除新文件,复用现有缓存 + os.Remove(filePath) + cm.items.Store(key, existingItem) + log.Printf("[Cache] Found duplicate content for %s, reusing existing cache", key.URL) + return existingItem, nil + } + + // 存储新的缓存项 + cm.items.Store(key, item) + log.Printf("[Cache] Cached %s (%s)", key.URL, formatBytes(item.Size)) + return item, nil +} + +// shouldCache 检查响应是否应该被缓存 +func (cm *CacheManager) shouldCache(resp *http.Response) bool { + // 检查状态码 + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotModified { + return false + } + + // 解析 Cache-Control 头 + cacheControl := resp.Header.Get("Cache-Control") + if strings.Contains(cacheControl, "no-store") || + strings.Contains(cacheControl, "no-cache") || + strings.Contains(cacheControl, "private") { + return false + } + + return true +} + +// cleanup 定期清理过期的缓存项 +func (cm *CacheManager) cleanup() { + ticker := time.NewTicker(cm.cleanupTick) + for range ticker.C { + var totalSize int64 + var keysToDelete []CacheKey + + // 收集需要删除的键和计算总大小 + cm.items.Range(func(k, v interface{}) bool { + key := k.(CacheKey) + item := v.(*CacheItem) + totalSize += item.Size + + if time.Since(item.LastAccess) > cm.maxAge { + keysToDelete = append(keysToDelete, key) + } + return true + }) + + // 如果总大小超过限制,按最后访问时间排序删除 + if totalSize > cm.maxCacheSize { + var items []*CacheItem + cm.items.Range(func(k, v interface{}) bool { + items = append(items, v.(*CacheItem)) + return true + }) + + // 按最后访问时间排序 + sort.Slice(items, func(i, j int) bool { + return items[i].LastAccess.Before(items[j].LastAccess) + }) + + // 删除最旧的项直到总大小小于限制 + for _, item := range items { + if totalSize <= cm.maxCacheSize { + break + } + cm.items.Range(func(k, v interface{}) bool { + if v.(*CacheItem) == item { + keysToDelete = append(keysToDelete, k.(CacheKey)) + totalSize -= item.Size + return false + } + return true + }) + } + } + + // 删除过期和超出大小限制的缓存项 + for _, key := range keysToDelete { + if item, ok := cm.items.Load(key); ok { + cacheItem := item.(*CacheItem) + os.Remove(cacheItem.FilePath) + cm.items.Delete(key) + log.Printf("[Cache] Removed expired item: %s", key.URL) + } + } + } +} + +// formatBytes 格式化字节大小 +func formatBytes(bytes int64) string { + const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + ) + + switch { + case bytes >= GB: + return fmt.Sprintf("%.2f GB", float64(bytes)/float64(GB)) + case bytes >= MB: + return fmt.Sprintf("%.2f MB", float64(bytes)/float64(MB)) + case bytes >= KB: + return fmt.Sprintf("%.2f KB", float64(bytes)/float64(KB)) + default: + return fmt.Sprintf("%d B", bytes) + } +} + +// GetStats 获取缓存统计信息 +func (cm *CacheManager) GetStats() CacheStats { + var totalItems int + var totalSize int64 + + cm.items.Range(func(_, value interface{}) bool { + item := value.(*CacheItem) + totalItems++ + totalSize += item.Size + return true + }) + + hitCount := cm.hitCount.Load() + missCount := cm.missCount.Load() + totalRequests := hitCount + missCount + hitRate := float64(0) + if totalRequests > 0 { + hitRate = float64(hitCount) / float64(totalRequests) * 100 + } + + return CacheStats{ + TotalItems: totalItems, + TotalSize: totalSize, + HitCount: hitCount, + MissCount: missCount, + HitRate: hitRate, + BytesSaved: cm.bytesSaved.Load(), + Enabled: cm.enabled.Load(), + } +} + +// SetEnabled 设置缓存开关状态 +func (cm *CacheManager) SetEnabled(enabled bool) { + cm.enabled.Store(enabled) +} + +// ClearCache 清空缓存 +func (cm *CacheManager) ClearCache() error { + // 删除所有缓存文件 + var keysToDelete []CacheKey + cm.items.Range(func(key, value interface{}) bool { + cacheKey := key.(CacheKey) + item := value.(*CacheItem) + os.Remove(item.FilePath) + keysToDelete = append(keysToDelete, cacheKey) + return true + }) + + // 清除缓存项 + for _, key := range keysToDelete { + cm.items.Delete(key) + } + + // 重置统计信息 + cm.hitCount.Store(0) + cm.missCount.Store(0) + cm.bytesSaved.Store(0) + + return nil +} diff --git a/internal/handler/cache_admin.go b/internal/handler/cache_admin.go new file mode 100644 index 0000000..1226e0f --- /dev/null +++ b/internal/handler/cache_admin.go @@ -0,0 +1,105 @@ +package handler + +import ( + "encoding/json" + "net/http" + "proxy-go/internal/cache" +) + +type CacheAdminHandler struct { + proxyCache *cache.CacheManager + mirrorCache *cache.CacheManager +} + +func NewCacheAdminHandler(proxyCache, mirrorCache *cache.CacheManager) *CacheAdminHandler { + return &CacheAdminHandler{ + proxyCache: proxyCache, + mirrorCache: mirrorCache, + } +} + +// GetCacheStats 获取缓存统计信息 +func (h *CacheAdminHandler) GetCacheStats(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + stats := map[string]cache.CacheStats{ + "proxy": h.proxyCache.GetStats(), + "mirror": h.mirrorCache.GetStats(), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(stats) +} + +// SetCacheEnabled 设置缓存开关状态 +func (h *CacheAdminHandler) SetCacheEnabled(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Type string `json:"type"` // "proxy" 或 "mirror" + Enabled bool `json:"enabled"` // true 或 false + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + switch req.Type { + case "proxy": + h.proxyCache.SetEnabled(req.Enabled) + case "mirror": + h.mirrorCache.SetEnabled(req.Enabled) + default: + http.Error(w, "Invalid cache type", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) +} + +// ClearCache 清空缓存 +func (h *CacheAdminHandler) ClearCache(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Type string `json:"type"` // "proxy", "mirror" 或 "all" + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + var err error + switch req.Type { + case "proxy": + err = h.proxyCache.ClearCache() + case "mirror": + err = h.mirrorCache.ClearCache() + case "all": + err = h.proxyCache.ClearCache() + if err == nil { + err = h.mirrorCache.ClearCache() + } + default: + http.Error(w, "Invalid cache type", http.StatusBadRequest) + return + } + + if err != nil { + http.Error(w, "Failed to clear cache: "+err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/internal/handler/config.go b/internal/handler/config.go index f04db03..c6ed9f7 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -33,11 +33,6 @@ func (h *ConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -// handleConfigPage 处理配置页面请求 -func (h *ConfigHandler) handleConfigPage(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, "web/templates/config.html") -} - // handleGetConfig 处理获取配置请求 func (h *ConfigHandler) handleGetConfig(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/internal/handler/mirror_proxy.go b/internal/handler/mirror_proxy.go index 2a5b4cb..9ea4b72 100644 --- a/internal/handler/mirror_proxy.go +++ b/internal/handler/mirror_proxy.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "net/url" + "proxy-go/internal/cache" "proxy-go/internal/metrics" "proxy-go/internal/utils" "strings" @@ -14,6 +15,7 @@ import ( type MirrorProxyHandler struct { client *http.Client + Cache *cache.CacheManager } func NewMirrorProxyHandler() *MirrorProxyHandler { @@ -23,11 +25,18 @@ func NewMirrorProxyHandler() *MirrorProxyHandler { IdleConnTimeout: 90 * time.Second, } + // 初始化缓存管理器 + cacheManager, err := cache.NewCacheManager("data/mirror_cache") + if err != nil { + log.Printf("[Cache] Failed to initialize mirror cache manager: %v", err) + } + return &MirrorProxyHandler{ client: &http.Client{ Transport: transport, Timeout: 30 * time.Second, }, + Cache: cacheManager, } } @@ -107,6 +116,23 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxyReq.Header.Set("Host", parsedURL.Host) proxyReq.Host = parsedURL.Host + // 检查是否可以使用缓存 + if r.Method == http.MethodGet && h.Cache != nil { + cacheKey := h.Cache.GenerateCacheKey(r) + if item, hit, notModified := h.Cache.Get(cacheKey, r); hit { + // 从缓存提供响应 + w.Header().Set("Content-Type", item.ContentType) + w.Header().Set("Proxy-Go-Cache", "HIT") + if notModified { + w.WriteHeader(http.StatusNotModified) + return + } + http.ServeFile(w, r, item.FilePath) + collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(startTime), item.Size, utils.GetClientIP(r), r) + return + } + } + // 发送请求 resp, err := h.client.Do(proxyReq) if err != nil { @@ -118,25 +144,42 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer resp.Body.Close() + // 读取响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Error reading response", http.StatusInternalServerError) + log.Printf("Error reading response: %v", err) + return + } + + // 如果是GET请求且响应成功,尝试缓存 + if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil { + cacheKey := h.Cache.GenerateCacheKey(r) + if _, err := h.Cache.Put(cacheKey, resp, body); err != nil { + log.Printf("[Cache] Failed to cache %s: %v", actualURL, err) + } + } + // 复制响应头 copyHeader(w.Header(), resp.Header) + w.Header().Set("Proxy-Go-Cache", "MISS") // 设置状态码 w.WriteHeader(resp.StatusCode) - // 复制响应体 - bytesCopied, err := io.Copy(w, resp.Body) + // 写入响应体 + written, err := w.Write(body) if err != nil { - log.Printf("Error copying response: %v", err) + log.Printf("Error writing response: %v", err) return } // 记录访问日志 log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s", r.Method, resp.StatusCode, time.Since(startTime), - utils.GetClientIP(r), utils.FormatBytes(bytesCopied), + utils.GetClientIP(r), utils.FormatBytes(int64(written)), utils.GetRequestSource(r), actualURL) // 记录统计信息 - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), bytesCopied, utils.GetClientIP(r), r) + collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r) } diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index ac25453..a95af45 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "proxy-go/internal/cache" "proxy-go/internal/config" "proxy-go/internal/metrics" "proxy-go/internal/utils" @@ -17,7 +18,6 @@ import ( "time" "golang.org/x/net/http2" - "golang.org/x/time/rate" ) const ( @@ -30,15 +30,6 @@ const ( backendServTimeout = 40 * time.Second idleConnTimeout = 120 * time.Second tlsHandshakeTimeout = 10 * time.Second - - // 限流相关常量 - globalRateLimit = 2000 - globalBurstLimit = 500 - perHostRateLimit = 200 - perHostBurstLimit = 100 - perIPRateLimit = 50 - perIPBurstLimit = 20 - cleanupInterval = 10 * time.Minute ) // 统一的缓冲池 @@ -48,17 +39,6 @@ var bufferPool = sync.Pool{ }, } -// getBuffer 获取缓冲区 -func getBuffer() (*bytes.Buffer, func()) { - buf := bufferPool.Get().(*bytes.Buffer) - buf.Reset() - return buf, func() { - if buf != nil { - bufferPool.Put(buf) - } - } -} - // 添加 hop-by-hop 头部映射 var hopHeadersMap = make(map[string]bool) @@ -82,115 +62,14 @@ func init() { // ErrorHandler 定义错误处理函数类型 type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) -// RateLimiter 定义限流器接口 -type RateLimiter interface { - Allow() bool - Clean(now time.Time) -} - -// 限流管理器 -type rateLimitManager struct { - globalLimiter *rate.Limiter - hostLimiters *sync.Map // host -> *rate.Limiter - ipLimiters *sync.Map // IP -> *rate.Limiter - lastCleanup time.Time -} - -// 创建新的限流管理器 -func newRateLimitManager() *rateLimitManager { - manager := &rateLimitManager{ - globalLimiter: rate.NewLimiter(rate.Limit(globalRateLimit), globalBurstLimit), - hostLimiters: &sync.Map{}, - ipLimiters: &sync.Map{}, - lastCleanup: time.Now(), - } - - // 启动清理协程 - go manager.cleanupLoop() - return manager -} - -func (m *rateLimitManager) cleanupLoop() { - ticker := time.NewTicker(cleanupInterval) - for range ticker.C { - now := time.Now() - m.cleanup(now) - } -} - -func (m *rateLimitManager) cleanup(now time.Time) { - m.hostLimiters.Range(func(key, value interface{}) bool { - if now.Sub(m.lastCleanup) > cleanupInterval { - m.hostLimiters.Delete(key) - } - return true - }) - - m.ipLimiters.Range(func(key, value interface{}) bool { - if now.Sub(m.lastCleanup) > cleanupInterval { - m.ipLimiters.Delete(key) - } - return true - }) - - m.lastCleanup = now -} - -func (m *rateLimitManager) getHostLimiter(host string) *rate.Limiter { - if limiter, exists := m.hostLimiters.Load(host); exists { - return limiter.(*rate.Limiter) - } - - limiter := rate.NewLimiter(rate.Limit(perHostRateLimit), perHostBurstLimit) - m.hostLimiters.Store(host, limiter) - return limiter -} - -func (m *rateLimitManager) getIPLimiter(ip string) *rate.Limiter { - if limiter, exists := m.ipLimiters.Load(ip); exists { - return limiter.(*rate.Limiter) - } - - limiter := rate.NewLimiter(rate.Limit(perIPRateLimit), perIPBurstLimit) - m.ipLimiters.Store(ip, limiter) - return limiter -} - -// 检查是否允许请求 -func (m *rateLimitManager) allowRequest(r *http.Request) error { - // 全局限流检查 - if !m.globalLimiter.Allow() { - return fmt.Errorf("global rate limit exceeded") - } - - // Host限流检查 - host := r.Host - if host != "" { - if !m.getHostLimiter(host).Allow() { - return fmt.Errorf("host rate limit exceeded for %s", host) - } - } - - // IP限流检查 - ip := utils.GetClientIP(r) - if ip != "" { - if !m.getIPLimiter(ip).Allow() { - return fmt.Errorf("ip rate limit exceeded for %s", ip) - } - } - - return nil -} - type ProxyHandler struct { pathMap map[string]config.PathConfig client *http.Client - limiter *rate.Limiter startTime time.Time config *config.Config auth *authManager - errorHandler ErrorHandler // 添加错误处理器 - rateLimiter *rateLimitManager + errorHandler ErrorHandler + Cache *cache.CacheManager } // NewProxyHandler 创建新的代理处理器 @@ -202,12 +81,12 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler { transport := &http.Transport{ DialContext: dialer.DialContext, - MaxIdleConns: 1000, // 增加最大空闲连接数 - MaxIdleConnsPerHost: 100, // 增加每个主机的最大空闲连接数 + MaxIdleConns: 1000, + MaxIdleConnsPerHost: 100, IdleConnTimeout: idleConnTimeout, TLSHandshakeTimeout: tlsHandshakeTimeout, ExpectContinueTimeout: 1 * time.Second, - MaxConnsPerHost: 200, // 增加每个主机的最大连接数 + MaxConnsPerHost: 200, DisableKeepAlives: false, DisableCompression: false, ForceAttemptHTTP2: true, @@ -227,6 +106,12 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler { http2Transport.StrictMaxConcurrentStreams = true } + // 初始化缓存管理器 + cacheManager, err := cache.NewCacheManager("data/cache") + if err != nil { + log.Printf("[Cache] Failed to initialize cache manager: %v", err) + } + handler := &ProxyHandler{ pathMap: cfg.MAP, client: &http.Client{ @@ -239,20 +124,14 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler { return nil }, }, - limiter: rate.NewLimiter(rate.Limit(5000), 10000), - startTime: time.Now(), - config: cfg, - auth: newAuthManager(), - rateLimiter: newRateLimitManager(), + startTime: time.Now(), + config: cfg, + auth: newAuthManager(), + Cache: cacheManager, errorHandler: func(w http.ResponseWriter, r *http.Request, err error) { log.Printf("[Error] %s %s -> %v", r.Method, r.URL.Path, err) - if strings.Contains(err.Error(), "rate limit exceeded") { - w.WriteHeader(http.StatusTooManyRequests) - w.Write([]byte("Too Many Requests")) - } else { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Internal Server Error")) - } + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Server Error")) }, } @@ -266,13 +145,6 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler { return handler } -// SetErrorHandler 允许自定义错误处理函数 -func (h *ProxyHandler) SetErrorHandler(handler ErrorHandler) { - if handler != nil { - h.errorHandler = handler - } -} - // copyResponse 使用缓冲方式传输数据 func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) { buf := bufferPool.Get().(*bytes.Buffer) @@ -329,12 +201,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { collector.BeginRequest() defer collector.EndRequest() - // 限流检查 - if err := h.rateLimiter.allowRequest(r); err != nil { - h.errorHandler(w, r, err) - return - } - start := time.Now() // 创建带超时的上下文 @@ -457,6 +323,23 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + // 检查是否可以使用缓存 + if r.Method == http.MethodGet && h.Cache != nil { + cacheKey := h.Cache.GenerateCacheKey(r) + if item, hit, notModified := h.Cache.Get(cacheKey, r); hit { + // 从缓存提供响应 + w.Header().Set("Content-Type", item.ContentType) + w.Header().Set("Proxy-Go-Cache", "HIT") + if notModified { + w.WriteHeader(http.StatusNotModified) + return + } + http.ServeFile(w, r, item.FilePath) + collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(start), item.Size, utils.GetClientIP(r), r) + return + } + } + // 发送代理请求 resp, err := h.client.Do(proxyReq) if err != nil { @@ -473,55 +356,37 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer resp.Body.Close() - copyHeader(w.Header(), resp.Header) + // 读取响应体到缓冲区 + buf := new(bytes.Buffer) + _, err = copyResponse(buf, resp.Body, nil) + if err != nil { + h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err)) + return + } + body := buf.Bytes() - // 删除严格的 CSP + // 如果是GET请求且响应成功,尝试缓存 + if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil { + cacheKey := h.Cache.GenerateCacheKey(r) + if _, err := h.Cache.Put(cacheKey, resp, body); err != nil { + log.Printf("[Cache] Failed to cache %s: %v", r.URL.Path, err) + } + } + + // 设置响应头 + copyHeader(w.Header(), resp.Header) + w.Header().Set("Proxy-Go-Cache", "MISS") w.Header().Del("Content-Security-Policy") - // 根据响应大小选择不同的处理策略 - contentLength := resp.ContentLength - if contentLength > 0 && contentLength < 1<<20 { // 1MB 以下的小响应 - // 获取合适大小的缓冲区 - buf, putBuffer := getBuffer() - defer putBuffer() - - // 使用缓冲区读取响应 - _, err := io.Copy(buf, resp.Body) - if err != nil { - if !isConnectionClosed(err) { - h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err)) - } - return - } - - // 设置响应状态码并一次性写入响应 - w.WriteHeader(resp.StatusCode) - written, err := w.Write(buf.Bytes()) - if err != nil { - if !isConnectionClosed(err) { - log.Printf("Error writing response: %v", err) - } - } - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(written), utils.GetClientIP(r), r) - } else { - // 大响应使用零拷贝传输 - w.WriteHeader(resp.StatusCode) - var bytesCopied int64 - var err error - - if f, ok := w.(http.Flusher); ok { - bytesCopied, err = copyResponse(w, resp.Body, f) - } else { - bytesCopied, err = copyResponse(w, resp.Body, nil) - } - - if err != nil && !isConnectionClosed(err) { - log.Printf("Error copying response: %v", err) - } - - // 记录访问日志 - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), bytesCopied, utils.GetClientIP(r), r) + // 写入响应 + w.WriteHeader(resp.StatusCode) + n, err := w.Write(body) + if err != nil && !isConnectionClosed(err) { + log.Printf("Error writing response: %v", err) } + + // 记录访问日志 + collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(n), utils.GetClientIP(r), r) } func copyHeader(dst, src http.Header) { diff --git a/internal/metrics/collector.go b/internal/metrics/collector.go index 4d86bb0..c689a40 100644 --- a/internal/metrics/collector.go +++ b/internal/metrics/collector.go @@ -186,12 +186,12 @@ func (c *Collector) GetStats() map[string]interface{} { // 收集路径统计 var pathMetrics []models.PathMetrics c.pathStats.Range(func(key, value interface{}) bool { - stats := value.(models.PathMetrics) + stats := value.(*models.PathMetrics) if stats.RequestCount > 0 { avgLatencyMs := float64(stats.TotalLatency) / float64(stats.RequestCount) / float64(time.Millisecond) stats.AvgLatency = fmt.Sprintf("%.2fms", avgLatencyMs) } - pathMetrics = append(pathMetrics, stats) + pathMetrics = append(pathMetrics, *stats) return true }) diff --git a/main.go b/main.go index 0b7a799..d976d44 100644 --- a/main.go +++ b/main.go @@ -80,6 +80,12 @@ func main() { proxyHandler.AuthMiddleware(handler.NewConfigHandler(cfg).ServeHTTP)(w, r) case "/admin/api/config/save": proxyHandler.AuthMiddleware(handler.NewConfigHandler(cfg).ServeHTTP)(w, r) + case "/admin/api/cache/stats": + proxyHandler.AuthMiddleware(handler.NewCacheAdminHandler(proxyHandler.Cache, mirrorHandler.Cache).GetCacheStats)(w, r) + case "/admin/api/cache/enable": + proxyHandler.AuthMiddleware(handler.NewCacheAdminHandler(proxyHandler.Cache, mirrorHandler.Cache).SetCacheEnabled)(w, r) + case "/admin/api/cache/clear": + proxyHandler.AuthMiddleware(handler.NewCacheAdminHandler(proxyHandler.Cache, mirrorHandler.Cache).ClearCache)(w, r) default: http.NotFound(w, r) } diff --git a/web/app/dashboard/cache/page.tsx b/web/app/dashboard/cache/page.tsx new file mode 100644 index 0000000..88a5afe --- /dev/null +++ b/web/app/dashboard/cache/page.tsx @@ -0,0 +1,236 @@ +"use client" + +import { useEffect, useState, useCallback } from "react" +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card" +import { Button } from "@/components/ui/button" +import { useToast } from "@/components/ui/use-toast" +import { Switch } from "@/components/ui/switch" + +interface CacheStats { + total_items: number + total_size: number + hit_count: number + miss_count: number + hit_rate: number + bytes_saved: number + enabled: boolean +} + +interface CacheData { + proxy: CacheStats + mirror: CacheStats +} + +function formatBytes(bytes: number) { + const units = ['B', 'KB', 'MB', 'GB'] + let size = bytes + let unitIndex = 0 + + while (size >= 1024 && unitIndex < units.length - 1) { + size /= 1024 + unitIndex++ + } + + return `${size.toFixed(2)} ${units[unitIndex]}` +} + +export default function CachePage() { + const [stats, setStats] = useState(null) + const [loading, setLoading] = useState(true) + const { toast } = useToast() + + const fetchStats = useCallback(async () => { + try { + const response = await fetch("/admin/api/cache/stats") + if (!response.ok) throw new Error("获取缓存统计失败") + const data = await response.json() + setStats(data) + } catch (error) { + toast({ + title: "错误", + description: error instanceof Error ? error.message : "获取缓存统计失败", + variant: "destructive", + }) + } finally { + setLoading(false) + } + }, [toast]) + + useEffect(() => { + // 立即获取一次数据 + fetchStats() + + // 设置定时刷新 + const interval = setInterval(fetchStats, 5000) + return () => clearInterval(interval) + }, [fetchStats]) + + const handleToggleCache = async (type: "proxy" | "mirror", enabled: boolean) => { + try { + const response = await fetch("/admin/api/cache/enable", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ type, enabled }), + }) + + if (!response.ok) throw new Error("切换缓存状态失败") + + toast({ + title: "成功", + description: `${type === "proxy" ? "代理" : "镜像"}缓存已${enabled ? "启用" : "禁用"}`, + }) + + fetchStats() + } catch (error) { + toast({ + title: "错误", + description: error instanceof Error ? error.message : "切换缓存状态失败", + variant: "destructive", + }) + } + } + + const handleClearCache = async (type: "proxy" | "mirror" | "all") => { + try { + const response = await fetch("/admin/api/cache/clear", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ type }), + }) + + if (!response.ok) throw new Error("清理缓存失败") + + toast({ + title: "成功", + description: "缓存已清理", + }) + + fetchStats() + } catch (error) { + toast({ + title: "错误", + description: error instanceof Error ? error.message : "清理缓存失败", + variant: "destructive", + }) + } + } + + if (loading) { + return ( +
+
+
加载中...
+
正在获取缓存统计信息
+
+
+ ) + } + + return ( +
+
+ +
+ +
+ {/* 代理缓存 */} + + + 代理缓存 +
+ handleToggleCache("proxy", checked)} + /> + +
+
+ +
+
+
缓存项数量
+
{stats?.proxy.total_items ?? 0}
+
+
+
总大小
+
{formatBytes(stats?.proxy.total_size ?? 0)}
+
+
+
命中次数
+
{stats?.proxy.hit_count ?? 0}
+
+
+
未命中次数
+
{stats?.proxy.miss_count ?? 0}
+
+
+
命中率
+
{(stats?.proxy.hit_rate ?? 0).toFixed(2)}%
+
+
+
节省带宽
+
{formatBytes(stats?.proxy.bytes_saved ?? 0)}
+
+
+
+
+ + {/* 镜像缓存 */} + + + 镜像缓存 +
+ handleToggleCache("mirror", checked)} + /> + +
+
+ +
+
+
缓存项数量
+
{stats?.mirror.total_items ?? 0}
+
+
+
总大小
+
{formatBytes(stats?.mirror.total_size ?? 0)}
+
+
+
命中次数
+
{stats?.mirror.hit_count ?? 0}
+
+
+
未命中次数
+
{stats?.mirror.miss_count ?? 0}
+
+
+
命中率
+
{(stats?.mirror.hit_rate ?? 0).toFixed(2)}%
+
+
+
节省带宽
+
{formatBytes(stats?.mirror.bytes_saved ?? 0)}
+
+
+
+
+
+
+ ) +} \ No newline at end of file diff --git a/web/components/nav.tsx b/web/components/nav.tsx index aceb470..f39c18c 100644 --- a/web/components/nav.tsx +++ b/web/components/nav.tsx @@ -49,6 +49,12 @@ export function Nav() { > 配置 + + 缓存 +