feat(cache): Implement streaming cache with TeeReader for efficient response handling

- Replace full response body reading with streaming cache mechanism
- Add CreateTemp and Commit methods to CacheManager for incremental caching
- Use io.TeeReader to simultaneously write response to client and cache file
- Remove buffer pool and full body reading in proxy and mirror handlers
- Improve memory efficiency and reduce latency for large responses
- Update error handling and logging for cache-related operations
This commit is contained in:
wood chen 2025-02-15 17:48:13 +08:00
parent ffc64bb73a
commit 2267a27b37
4 changed files with 129 additions and 122 deletions

View File

@ -350,3 +350,58 @@ func (cm *CacheManager) ClearCache() error {
return nil return nil
} }
// CreateTemp 创建临时缓存文件
func (cm *CacheManager) CreateTemp(key CacheKey, resp *http.Response) (*os.File, error) {
if !cm.enabled.Load() {
return nil, fmt.Errorf("cache is disabled")
}
// 创建临时文件
tempFile, err := os.CreateTemp(cm.cacheDir, "temp-*")
if err != nil {
return nil, fmt.Errorf("failed to create temp file: %v", err)
}
return tempFile, nil
}
// Commit 提交缓存文件
func (cm *CacheManager) Commit(key CacheKey, tempPath string, resp *http.Response, size int64) error {
if !cm.enabled.Load() {
os.Remove(tempPath)
return fmt.Errorf("cache is disabled")
}
// 生成最终的缓存文件名
h := sha256.New()
h.Write([]byte(key.String()))
hashStr := hex.EncodeToString(h.Sum(nil))
ext := filepath.Ext(key.URL)
if ext == "" {
ext = ".bin"
}
filePath := filepath.Join(cm.cacheDir, hashStr+ext)
// 重命名临时文件
if err := os.Rename(tempPath, filePath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to rename temp file: %v", err)
}
// 创建缓存项
item := &CacheItem{
FilePath: filePath,
ContentType: resp.Header.Get("Content-Type"),
Size: size,
LastAccess: time.Now(),
Hash: hashStr,
CreatedAt: time.Now(),
AccessCount: 1,
}
cm.items.Store(key, item)
cm.bytesSaved.Add(size)
log.Printf("[Cache] Cached %s (%s)", key.URL, formatBytes(size))
return nil
}

View File

@ -144,22 +144,6 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
defer resp.Body.Close() defer resp.Body.Close()
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Error reading response", http.StatusInternalServerError)
log.Printf("Error reading response: %v", err)
return
}
// 如果是GET请求且响应成功尝试缓存
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
cacheKey := h.Cache.GenerateCacheKey(r)
if _, err := h.Cache.Put(cacheKey, resp, body); err != nil {
log.Printf("[Cache] Failed to cache %s: %v", actualURL, err)
}
}
// 复制响应头 // 复制响应头
copyHeader(w.Header(), resp.Header) copyHeader(w.Header(), resp.Header)
w.Header().Set("Proxy-Go-Cache", "MISS") w.Header().Set("Proxy-Go-Cache", "MISS")
@ -167,19 +151,38 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 设置状态码 // 设置状态码
w.WriteHeader(resp.StatusCode) w.WriteHeader(resp.StatusCode)
// 写入响应体 var written int64
written, err := w.Write(body) // 如果是GET请求且响应成功使用TeeReader同时写入缓存
if err != nil { if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
cacheKey := h.Cache.GenerateCacheKey(r)
if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil {
defer cacheFile.Close()
teeReader := io.TeeReader(resp.Body, cacheFile)
written, err = io.Copy(w, teeReader)
if err == nil {
h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written)
}
} else {
written, err = io.Copy(w, resp.Body)
if err != nil && !isConnectionClosed(err) {
log.Printf("Error writing response: %v", err) log.Printf("Error writing response: %v", err)
return return
} }
}
} else {
written, err = io.Copy(w, resp.Body)
if err != nil && !isConnectionClosed(err) {
log.Printf("Error writing response: %v", err)
return
}
}
// 记录访问日志 // 记录访问日志
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s",
r.Method, resp.StatusCode, time.Since(startTime), r.Method, resp.StatusCode, time.Since(startTime),
utils.GetClientIP(r), utils.FormatBytes(int64(written)), utils.GetClientIP(r), utils.FormatBytes(written),
utils.GetRequestSource(r), actualURL) utils.GetRequestSource(r), actualURL)
// 记录统计信息 // 记录统计信息
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r)
} }

View File

@ -1,7 +1,6 @@
package handler package handler
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io" "io"
@ -14,16 +13,12 @@ import (
"proxy-go/internal/metrics" "proxy-go/internal/metrics"
"proxy-go/internal/utils" "proxy-go/internal/utils"
"strings" "strings"
"sync"
"time" "time"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
const ( const (
// 缓冲区大小
defaultBufferSize = 32 * 1024 // 32KB
// 超时时间常量 // 超时时间常量
clientConnTimeout = 10 * time.Second clientConnTimeout = 10 * time.Second
proxyRespTimeout = 60 * time.Second proxyRespTimeout = 60 * time.Second
@ -32,13 +27,6 @@ const (
tlsHandshakeTimeout = 10 * time.Second tlsHandshakeTimeout = 10 * time.Second
) )
// 统一的缓冲池
var bufferPool = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, defaultBufferSize))
},
}
// 添加 hop-by-hop 头部映射 // 添加 hop-by-hop 头部映射
var hopHeadersMap = make(map[string]bool) var hopHeadersMap = make(map[string]bool)
@ -145,49 +133,6 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler {
return handler return handler
} }
// copyResponse 使用缓冲方式传输数据
func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) {
buf := bufferPool.Get().(*bytes.Buffer)
defer bufferPool.Put(buf)
buf.Reset()
var written int64
for {
// 清空缓冲区
buf.Reset()
// 读取数据到缓冲区
_, er := io.CopyN(buf, src, defaultBufferSize)
if er != nil && er != io.EOF {
return written, er
}
// 如果有数据,写入目标
if buf.Len() > 0 {
nw, ew := dst.Write(buf.Bytes())
if ew != nil {
return written, ew
}
written += int64(nw)
// 定期刷新缓冲区
if flusher != nil && written%(1024*1024) == 0 { // 每1MB刷新一次
flusher.Flush()
}
}
if er == io.EOF {
break
}
}
// 最后一次刷新
if flusher != nil {
flusher.Flush()
}
return written, nil
}
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 添加 panic 恢复 // 添加 panic 恢复
defer func() { defer func() {
@ -356,37 +301,41 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
defer resp.Body.Close() defer resp.Body.Close()
// 读取响应体到缓冲区 // 复制响应头
buf := new(bytes.Buffer)
_, err = copyResponse(buf, resp.Body, nil)
if err != nil {
h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err))
return
}
body := buf.Bytes()
// 如果是GET请求且响应成功尝试缓存
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
cacheKey := h.Cache.GenerateCacheKey(r)
if _, err := h.Cache.Put(cacheKey, resp, body); err != nil {
log.Printf("[Cache] Failed to cache %s: %v", r.URL.Path, err)
}
}
// 设置响应头
copyHeader(w.Header(), resp.Header) copyHeader(w.Header(), resp.Header)
w.Header().Set("Proxy-Go-Cache", "MISS") w.Header().Set("Proxy-Go-Cache", "MISS")
w.Header().Del("Content-Security-Policy")
// 写入响应 // 设置响应状态码
w.WriteHeader(resp.StatusCode) w.WriteHeader(resp.StatusCode)
n, err := w.Write(body)
var written int64
// 如果是GET请求且响应成功使用TeeReader同时写入缓存
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
cacheKey := h.Cache.GenerateCacheKey(r)
if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil {
defer cacheFile.Close()
teeReader := io.TeeReader(resp.Body, cacheFile)
written, err = io.Copy(w, teeReader)
if err == nil {
h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written)
}
} else {
written, err = io.Copy(w, resp.Body)
if err != nil && !isConnectionClosed(err) { if err != nil && !isConnectionClosed(err) {
log.Printf("Error writing response: %v", err) log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
return
}
}
} else {
written, err = io.Copy(w, resp.Body)
if err != nil && !isConnectionClosed(err) {
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
return
}
} }
// 记录访问日志 // 记录统计信息
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(n), utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), written, utils.GetClientIP(r), r)
} }
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {

View File

@ -94,22 +94,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
} }
defer resp.Body.Close() defer resp.Body.Close()
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "Error reading response", http.StatusInternalServerError)
log.Printf("Error reading response: %v", err)
return
}
// 如果是GET请求且响应成功尝试缓存
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil {
cacheKey := fixedPathCache.GenerateCacheKey(r)
if _, err := fixedPathCache.Put(cacheKey, resp, body); err != nil {
log.Printf("[Cache] Failed to cache %s: %v", targetURL, err)
}
}
// 复制响应头 // 复制响应头
for key, values := range resp.Header { for key, values := range resp.Header {
for _, value := range values { for _, value := range values {
@ -121,14 +105,30 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
// 设置响应状态码 // 设置响应状态码
w.WriteHeader(resp.StatusCode) w.WriteHeader(resp.StatusCode)
// 写入响应体 var written int64
written, err := w.Write(body) // 如果是GET请求且响应成功使用TeeReader同时写入缓存
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil {
cacheKey := fixedPathCache.GenerateCacheKey(r)
if cacheFile, err := fixedPathCache.CreateTemp(cacheKey, resp); err == nil {
defer cacheFile.Close()
teeReader := io.TeeReader(resp.Body, cacheFile)
written, err = io.Copy(w, teeReader)
if err == nil {
fixedPathCache.Commit(cacheKey, cacheFile.Name(), resp, written)
}
} else {
written, err = io.Copy(w, resp.Body)
}
} else {
written, err = io.Copy(w, resp.Body)
}
if err != nil && !isConnectionClosed(err) { if err != nil && !isConnectionClosed(err) {
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err) log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
} }
// 记录统计信息 // 记录统计信息
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r)
return return
} }