diff --git a/internal/cache/manager.go b/internal/cache/manager.go index e44d69a..f9123d5 100644 --- a/internal/cache/manager.go +++ b/internal/cache/manager.go @@ -350,3 +350,58 @@ func (cm *CacheManager) ClearCache() error { return nil } + +// CreateTemp 创建临时缓存文件 +func (cm *CacheManager) CreateTemp(key CacheKey, resp *http.Response) (*os.File, error) { + if !cm.enabled.Load() { + return nil, fmt.Errorf("cache is disabled") + } + + // 创建临时文件 + tempFile, err := os.CreateTemp(cm.cacheDir, "temp-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %v", err) + } + + return tempFile, nil +} + +// Commit 提交缓存文件 +func (cm *CacheManager) Commit(key CacheKey, tempPath string, resp *http.Response, size int64) error { + if !cm.enabled.Load() { + os.Remove(tempPath) + return fmt.Errorf("cache is disabled") + } + + // 生成最终的缓存文件名 + h := sha256.New() + h.Write([]byte(key.String())) + hashStr := hex.EncodeToString(h.Sum(nil)) + ext := filepath.Ext(key.URL) + if ext == "" { + ext = ".bin" + } + filePath := filepath.Join(cm.cacheDir, hashStr+ext) + + // 重命名临时文件 + if err := os.Rename(tempPath, filePath); err != nil { + os.Remove(tempPath) + return fmt.Errorf("failed to rename temp file: %v", err) + } + + // 创建缓存项 + item := &CacheItem{ + FilePath: filePath, + ContentType: resp.Header.Get("Content-Type"), + Size: size, + LastAccess: time.Now(), + Hash: hashStr, + CreatedAt: time.Now(), + AccessCount: 1, + } + + cm.items.Store(key, item) + cm.bytesSaved.Add(size) + log.Printf("[Cache] Cached %s (%s)", key.URL, formatBytes(size)) + return nil +} diff --git a/internal/handler/mirror_proxy.go b/internal/handler/mirror_proxy.go index 9ea4b72..4e13e1a 100644 --- a/internal/handler/mirror_proxy.go +++ b/internal/handler/mirror_proxy.go @@ -144,22 +144,6 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer resp.Body.Close() - // 读取响应体 - body, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, "Error reading response", http.StatusInternalServerError) - log.Printf("Error reading response: %v", err) - return - } - - // 如果是GET请求且响应成功,尝试缓存 - if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil { - cacheKey := h.Cache.GenerateCacheKey(r) - if _, err := h.Cache.Put(cacheKey, resp, body); err != nil { - log.Printf("[Cache] Failed to cache %s: %v", actualURL, err) - } - } - // 复制响应头 copyHeader(w.Header(), resp.Header) w.Header().Set("Proxy-Go-Cache", "MISS") @@ -167,19 +151,38 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 设置状态码 w.WriteHeader(resp.StatusCode) - // 写入响应体 - written, err := w.Write(body) - if err != nil { - log.Printf("Error writing response: %v", err) - return + var written int64 + // 如果是GET请求且响应成功,使用TeeReader同时写入缓存 + if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil { + cacheKey := h.Cache.GenerateCacheKey(r) + if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil { + defer cacheFile.Close() + teeReader := io.TeeReader(resp.Body, cacheFile) + written, err = io.Copy(w, teeReader) + if err == nil { + h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written) + } + } else { + written, err = io.Copy(w, resp.Body) + if err != nil && !isConnectionClosed(err) { + log.Printf("Error writing response: %v", err) + return + } + } + } else { + written, err = io.Copy(w, resp.Body) + if err != nil && !isConnectionClosed(err) { + log.Printf("Error writing response: %v", err) + return + } } // 记录访问日志 log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s", r.Method, resp.StatusCode, time.Since(startTime), - utils.GetClientIP(r), utils.FormatBytes(int64(written)), + utils.GetClientIP(r), utils.FormatBytes(written), utils.GetRequestSource(r), actualURL) // 记录统计信息 - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r) + collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r) } diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index a95af45..dac0aba 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -1,7 +1,6 @@ package handler import ( - "bytes" "context" "fmt" "io" @@ -14,16 +13,12 @@ import ( "proxy-go/internal/metrics" "proxy-go/internal/utils" "strings" - "sync" "time" "golang.org/x/net/http2" ) const ( - // 缓冲区大小 - defaultBufferSize = 32 * 1024 // 32KB - // 超时时间常量 clientConnTimeout = 10 * time.Second proxyRespTimeout = 60 * time.Second @@ -32,13 +27,6 @@ const ( tlsHandshakeTimeout = 10 * time.Second ) -// 统一的缓冲池 -var bufferPool = sync.Pool{ - New: func() interface{} { - return bytes.NewBuffer(make([]byte, defaultBufferSize)) - }, -} - // 添加 hop-by-hop 头部映射 var hopHeadersMap = make(map[string]bool) @@ -145,49 +133,6 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler { return handler } -// copyResponse 使用缓冲方式传输数据 -func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) { - buf := bufferPool.Get().(*bytes.Buffer) - defer bufferPool.Put(buf) - buf.Reset() - - var written int64 - for { - // 清空缓冲区 - buf.Reset() - - // 读取数据到缓冲区 - _, er := io.CopyN(buf, src, defaultBufferSize) - if er != nil && er != io.EOF { - return written, er - } - - // 如果有数据,写入目标 - if buf.Len() > 0 { - nw, ew := dst.Write(buf.Bytes()) - if ew != nil { - return written, ew - } - written += int64(nw) - - // 定期刷新缓冲区 - if flusher != nil && written%(1024*1024) == 0 { // 每1MB刷新一次 - flusher.Flush() - } - } - - if er == io.EOF { - break - } - } - - // 最后一次刷新 - if flusher != nil { - flusher.Flush() - } - return written, nil -} - func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 添加 panic 恢复 defer func() { @@ -356,37 +301,41 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer resp.Body.Close() - // 读取响应体到缓冲区 - buf := new(bytes.Buffer) - _, err = copyResponse(buf, resp.Body, nil) - if err != nil { - h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err)) - return - } - body := buf.Bytes() + // 复制响应头 + copyHeader(w.Header(), resp.Header) + w.Header().Set("Proxy-Go-Cache", "MISS") - // 如果是GET请求且响应成功,尝试缓存 + // 设置响应状态码 + w.WriteHeader(resp.StatusCode) + + var written int64 + // 如果是GET请求且响应成功,使用TeeReader同时写入缓存 if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil { cacheKey := h.Cache.GenerateCacheKey(r) - if _, err := h.Cache.Put(cacheKey, resp, body); err != nil { - log.Printf("[Cache] Failed to cache %s: %v", r.URL.Path, err) + if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil { + defer cacheFile.Close() + teeReader := io.TeeReader(resp.Body, cacheFile) + written, err = io.Copy(w, teeReader) + if err == nil { + h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written) + } + } else { + written, err = io.Copy(w, resp.Body) + if err != nil && !isConnectionClosed(err) { + log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err) + return + } + } + } else { + written, err = io.Copy(w, resp.Body) + if err != nil && !isConnectionClosed(err) { + log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err) + return } } - // 设置响应头 - copyHeader(w.Header(), resp.Header) - w.Header().Set("Proxy-Go-Cache", "MISS") - w.Header().Del("Content-Security-Policy") - - // 写入响应 - w.WriteHeader(resp.StatusCode) - n, err := w.Write(body) - if err != nil && !isConnectionClosed(err) { - log.Printf("Error writing response: %v", err) - } - - // 记录访问日志 - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(n), utils.GetClientIP(r), r) + // 记录统计信息 + collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), written, utils.GetClientIP(r), r) } func copyHeader(dst, src http.Header) { diff --git a/internal/middleware/fixed_path_proxy.go b/internal/middleware/fixed_path_proxy.go index db9366d..f77a97e 100644 --- a/internal/middleware/fixed_path_proxy.go +++ b/internal/middleware/fixed_path_proxy.go @@ -94,22 +94,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle } defer resp.Body.Close() - // 读取响应体 - body, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, "Error reading response", http.StatusInternalServerError) - log.Printf("Error reading response: %v", err) - return - } - - // 如果是GET请求且响应成功,尝试缓存 - if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil { - cacheKey := fixedPathCache.GenerateCacheKey(r) - if _, err := fixedPathCache.Put(cacheKey, resp, body); err != nil { - log.Printf("[Cache] Failed to cache %s: %v", targetURL, err) - } - } - // 复制响应头 for key, values := range resp.Header { for _, value := range values { @@ -121,14 +105,30 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle // 设置响应状态码 w.WriteHeader(resp.StatusCode) - // 写入响应体 - written, err := w.Write(body) + var written int64 + // 如果是GET请求且响应成功,使用TeeReader同时写入缓存 + if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil { + cacheKey := fixedPathCache.GenerateCacheKey(r) + if cacheFile, err := fixedPathCache.CreateTemp(cacheKey, resp); err == nil { + defer cacheFile.Close() + teeReader := io.TeeReader(resp.Body, cacheFile) + written, err = io.Copy(w, teeReader) + if err == nil { + fixedPathCache.Commit(cacheKey, cacheFile.Name(), resp, written) + } + } else { + written, err = io.Copy(w, resp.Body) + } + } else { + written, err = io.Copy(w, resp.Body) + } + if err != nil && !isConnectionClosed(err) { log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err) } // 记录统计信息 - collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r) + collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r) return }