diff --git a/internal/cache/manager.go b/internal/cache/manager.go index 138843f..d771b41 100644 --- a/internal/cache/manager.go +++ b/internal/cache/manager.go @@ -19,42 +19,23 @@ import ( // CacheKey 用于标识缓存项的唯一键 type CacheKey struct { - URL string - AcceptHeaders string - UserAgent string - VaryHeadersMap map[string]string // 存储 Vary 头部的值 + URL string + AcceptHeaders string + UserAgent string + VaryHeaders string // 存储 Vary 头部的值,格式:key1=value1&key2=value2 } // String 实现 Stringer 接口,用于生成唯一的字符串表示 func (k CacheKey) String() string { - // 将 VaryHeadersMap 转换为有序的字符串 - var varyPairs []string - for key, value := range k.VaryHeadersMap { - varyPairs = append(varyPairs, key+"="+value) - } - sort.Strings(varyPairs) - varyStr := strings.Join(varyPairs, "&") - - return fmt.Sprintf("%s|%s|%s|%s", k.URL, k.AcceptHeaders, k.UserAgent, varyStr) + return fmt.Sprintf("%s|%s|%s|%s", k.URL, k.AcceptHeaders, k.UserAgent, k.VaryHeaders) } // Equal 比较两个 CacheKey 是否相等 func (k CacheKey) Equal(other CacheKey) bool { - if k.URL != other.URL || k.AcceptHeaders != other.AcceptHeaders || k.UserAgent != other.UserAgent { - return false - } - - if len(k.VaryHeadersMap) != len(other.VaryHeadersMap) { - return false - } - - for key, value := range k.VaryHeadersMap { - if otherValue, ok := other.VaryHeadersMap[key]; !ok || value != otherValue { - return false - } - } - - return true + return k.URL == other.URL && + k.AcceptHeaders == other.AcceptHeaders && + k.UserAgent == other.UserAgent && + k.VaryHeaders == other.VaryHeaders } // Hash 生成 CacheKey 的哈希值 @@ -128,10 +109,22 @@ func NewCacheManager(cacheDir string) (*CacheManager, error) { // GenerateCacheKey 生成缓存键 func (cm *CacheManager) GenerateCacheKey(r *http.Request) CacheKey { + // 处理 Vary 头部 + varyHeaders := make([]string, 0) + for _, vary := range strings.Split(r.Header.Get("Vary"), ",") { + vary = strings.TrimSpace(vary) + if vary != "" { + value := r.Header.Get(vary) + varyHeaders = append(varyHeaders, vary+"="+value) + } + } + sort.Strings(varyHeaders) + return CacheKey{ URL: r.URL.String(), AcceptHeaders: r.Header.Get("Accept"), UserAgent: r.Header.Get("User-Agent"), + VaryHeaders: strings.Join(varyHeaders, "&"), } } @@ -168,7 +161,9 @@ func (cm *CacheManager) Get(key CacheKey, r *http.Request) (*CacheItem, bool, bo // 检查 Vary 头部 for _, varyHeader := range item.VaryHeaders { - if r.Header.Get(varyHeader) != key.VaryHeadersMap[varyHeader] { + requestValue := r.Header.Get(varyHeader) + varyPair := varyHeader + "=" + requestValue + if !strings.Contains(key.VaryHeaders, varyPair) { cm.missCount.Add(1) return nil, false, false } diff --git a/internal/middleware/fixed_path_proxy.go b/internal/middleware/fixed_path_proxy.go index 1b7ac60..00e141e 100644 --- a/internal/middleware/fixed_path_proxy.go +++ b/internal/middleware/fixed_path_proxy.go @@ -5,6 +5,7 @@ import ( "io" "log" "net/http" + "proxy-go/internal/cache" "proxy-go/internal/config" "proxy-go/internal/metrics" "proxy-go/internal/utils" @@ -19,6 +20,16 @@ type FixedPathConfig struct { TargetURL string `json:"TargetURL"` } +var fixedPathCache *cache.CacheManager + +func init() { + var err error + fixedPathCache, err = cache.NewCacheManager("data/fixed_path_cache") + if err != nil { + log.Printf("[Cache] Failed to initialize fixed path cache manager: %v", err) + } +} + func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -34,6 +45,23 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle targetPath := strings.TrimPrefix(r.URL.Path, cfg.Path) targetURL := cfg.TargetURL + targetPath + // 检查是否可以使用缓存 + if r.Method == http.MethodGet && fixedPathCache != nil { + cacheKey := fixedPathCache.GenerateCacheKey(r) + if item, hit, notModified := fixedPathCache.Get(cacheKey, r); hit { + // 从缓存提供响应 + w.Header().Set("Content-Type", item.ContentType) + w.Header().Set("Proxy-Go-Cache", "HIT") + if notModified { + w.WriteHeader(http.StatusNotModified) + return + } + http.ServeFile(w, r, item.FilePath) + collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(startTime), item.Size, utils.GetClientIP(r), r) + return + } + } + proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { http.Error(w, "Error creating proxy request", http.StatusInternalServerError) @@ -66,24 +94,41 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle } defer resp.Body.Close() + // 读取响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Error reading response", http.StatusInternalServerError) + log.Printf("Error reading response: %v", err) + return + } + + // 如果是GET请求且响应成功,尝试缓存 + if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil { + cacheKey := fixedPathCache.GenerateCacheKey(r) + if _, err := fixedPathCache.Put(cacheKey, resp, body); err != nil { + log.Printf("[Cache] Failed to cache %s: %v", targetURL, err) + } + } + // 复制响应头 for key, values := range resp.Header { for _, value := range values { w.Header().Add(key, value) } } + w.Header().Set("Proxy-Go-Cache", "MISS") // 设置响应状态码 w.WriteHeader(resp.StatusCode) - // 复制响应体 - bytesCopied, err := io.Copy(w, resp.Body) - if err := handleCopyError(err); err != nil { - log.Printf("[%s] Error copying response: %v", utils.GetClientIP(r), err) + // 写入响应体 + written, err := w.Write(body) + if err != nil && !isConnectionClosed(err) { + log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err) } // 记录统计信息 - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), bytesCopied, utils.GetClientIP(r), r) + collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r) return } @@ -95,9 +140,9 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle } } -func handleCopyError(err error) error { +func isConnectionClosed(err error) bool { if err == nil { - return nil + return false } // 忽略常见的连接关闭错误 @@ -105,8 +150,8 @@ func handleCopyError(err error) error { errors.Is(err, syscall.ECONNRESET) || // connection reset by peer strings.Contains(err.Error(), "broken pipe") || strings.Contains(err.Error(), "connection reset by peer") { - return nil + return true } - return err + return false }