mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 08:31:55 +08:00
feat(cache): Implement streaming cache with TeeReader for efficient response handling
- Replace full response body reading with streaming cache mechanism - Add CreateTemp and Commit methods to CacheManager for incremental caching - Use io.TeeReader to simultaneously write response to client and cache file - Remove buffer pool and full body reading in proxy and mirror handlers - Improve memory efficiency and reduce latency for large responses - Update error handling and logging for cache-related operations
This commit is contained in:
parent
ffc64bb73a
commit
2267a27b37
55
internal/cache/manager.go
vendored
55
internal/cache/manager.go
vendored
@ -350,3 +350,58 @@ func (cm *CacheManager) ClearCache() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTemp 创建临时缓存文件
|
||||
func (cm *CacheManager) CreateTemp(key CacheKey, resp *http.Response) (*os.File, error) {
|
||||
if !cm.enabled.Load() {
|
||||
return nil, fmt.Errorf("cache is disabled")
|
||||
}
|
||||
|
||||
// 创建临时文件
|
||||
tempFile, err := os.CreateTemp(cm.cacheDir, "temp-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
return tempFile, nil
|
||||
}
|
||||
|
||||
// Commit 提交缓存文件
|
||||
func (cm *CacheManager) Commit(key CacheKey, tempPath string, resp *http.Response, size int64) error {
|
||||
if !cm.enabled.Load() {
|
||||
os.Remove(tempPath)
|
||||
return fmt.Errorf("cache is disabled")
|
||||
}
|
||||
|
||||
// 生成最终的缓存文件名
|
||||
h := sha256.New()
|
||||
h.Write([]byte(key.String()))
|
||||
hashStr := hex.EncodeToString(h.Sum(nil))
|
||||
ext := filepath.Ext(key.URL)
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
filePath := filepath.Join(cm.cacheDir, hashStr+ext)
|
||||
|
||||
// 重命名临时文件
|
||||
if err := os.Rename(tempPath, filePath); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return fmt.Errorf("failed to rename temp file: %v", err)
|
||||
}
|
||||
|
||||
// 创建缓存项
|
||||
item := &CacheItem{
|
||||
FilePath: filePath,
|
||||
ContentType: resp.Header.Get("Content-Type"),
|
||||
Size: size,
|
||||
LastAccess: time.Now(),
|
||||
Hash: hashStr,
|
||||
CreatedAt: time.Now(),
|
||||
AccessCount: 1,
|
||||
}
|
||||
|
||||
cm.items.Store(key, item)
|
||||
cm.bytesSaved.Add(size)
|
||||
log.Printf("[Cache] Cached %s (%s)", key.URL, formatBytes(size))
|
||||
return nil
|
||||
}
|
||||
|
@ -144,22 +144,6 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Error reading response", http.StatusInternalServerError)
|
||||
log.Printf("Error reading response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是GET请求且响应成功,尝试缓存
|
||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
||||
cacheKey := h.Cache.GenerateCacheKey(r)
|
||||
if _, err := h.Cache.Put(cacheKey, resp, body); err != nil {
|
||||
log.Printf("[Cache] Failed to cache %s: %v", actualURL, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 复制响应头
|
||||
copyHeader(w.Header(), resp.Header)
|
||||
w.Header().Set("Proxy-Go-Cache", "MISS")
|
||||
@ -167,19 +151,38 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// 设置状态码
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 写入响应体
|
||||
written, err := w.Write(body)
|
||||
if err != nil {
|
||||
log.Printf("Error writing response: %v", err)
|
||||
return
|
||||
var written int64
|
||||
// 如果是GET请求且响应成功,使用TeeReader同时写入缓存
|
||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
||||
cacheKey := h.Cache.GenerateCacheKey(r)
|
||||
if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil {
|
||||
defer cacheFile.Close()
|
||||
teeReader := io.TeeReader(resp.Body, cacheFile)
|
||||
written, err = io.Copy(w, teeReader)
|
||||
if err == nil {
|
||||
h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written)
|
||||
}
|
||||
} else {
|
||||
written, err = io.Copy(w, resp.Body)
|
||||
if err != nil && !isConnectionClosed(err) {
|
||||
log.Printf("Error writing response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
written, err = io.Copy(w, resp.Body)
|
||||
if err != nil && !isConnectionClosed(err) {
|
||||
log.Printf("Error writing response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 记录访问日志
|
||||
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s",
|
||||
r.Method, resp.StatusCode, time.Since(startTime),
|
||||
utils.GetClientIP(r), utils.FormatBytes(int64(written)),
|
||||
utils.GetClientIP(r), utils.FormatBytes(written),
|
||||
utils.GetRequestSource(r), actualURL)
|
||||
|
||||
// 记录统计信息
|
||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r)
|
||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -14,16 +13,12 @@ import (
|
||||
"proxy-go/internal/metrics"
|
||||
"proxy-go/internal/utils"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
// 缓冲区大小
|
||||
defaultBufferSize = 32 * 1024 // 32KB
|
||||
|
||||
// 超时时间常量
|
||||
clientConnTimeout = 10 * time.Second
|
||||
proxyRespTimeout = 60 * time.Second
|
||||
@ -32,13 +27,6 @@ const (
|
||||
tlsHandshakeTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// 统一的缓冲池
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewBuffer(make([]byte, defaultBufferSize))
|
||||
},
|
||||
}
|
||||
|
||||
// 添加 hop-by-hop 头部映射
|
||||
var hopHeadersMap = make(map[string]bool)
|
||||
|
||||
@ -145,49 +133,6 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler {
|
||||
return handler
|
||||
}
|
||||
|
||||
// copyResponse 使用缓冲方式传输数据
|
||||
func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) {
|
||||
buf := bufferPool.Get().(*bytes.Buffer)
|
||||
defer bufferPool.Put(buf)
|
||||
buf.Reset()
|
||||
|
||||
var written int64
|
||||
for {
|
||||
// 清空缓冲区
|
||||
buf.Reset()
|
||||
|
||||
// 读取数据到缓冲区
|
||||
_, er := io.CopyN(buf, src, defaultBufferSize)
|
||||
if er != nil && er != io.EOF {
|
||||
return written, er
|
||||
}
|
||||
|
||||
// 如果有数据,写入目标
|
||||
if buf.Len() > 0 {
|
||||
nw, ew := dst.Write(buf.Bytes())
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
written += int64(nw)
|
||||
|
||||
// 定期刷新缓冲区
|
||||
if flusher != nil && written%(1024*1024) == 0 { // 每1MB刷新一次
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if er == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 最后一次刷新
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// 添加 panic 恢复
|
||||
defer func() {
|
||||
@ -356,37 +301,41 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体到缓冲区
|
||||
buf := new(bytes.Buffer)
|
||||
_, err = copyResponse(buf, resp.Body, nil)
|
||||
if err != nil {
|
||||
h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err))
|
||||
return
|
||||
}
|
||||
body := buf.Bytes()
|
||||
// 复制响应头
|
||||
copyHeader(w.Header(), resp.Header)
|
||||
w.Header().Set("Proxy-Go-Cache", "MISS")
|
||||
|
||||
// 如果是GET请求且响应成功,尝试缓存
|
||||
// 设置响应状态码
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
var written int64
|
||||
// 如果是GET请求且响应成功,使用TeeReader同时写入缓存
|
||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
||||
cacheKey := h.Cache.GenerateCacheKey(r)
|
||||
if _, err := h.Cache.Put(cacheKey, resp, body); err != nil {
|
||||
log.Printf("[Cache] Failed to cache %s: %v", r.URL.Path, err)
|
||||
if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil {
|
||||
defer cacheFile.Close()
|
||||
teeReader := io.TeeReader(resp.Body, cacheFile)
|
||||
written, err = io.Copy(w, teeReader)
|
||||
if err == nil {
|
||||
h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written)
|
||||
}
|
||||
} else {
|
||||
written, err = io.Copy(w, resp.Body)
|
||||
if err != nil && !isConnectionClosed(err) {
|
||||
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
written, err = io.Copy(w, resp.Body)
|
||||
if err != nil && !isConnectionClosed(err) {
|
||||
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
copyHeader(w.Header(), resp.Header)
|
||||
w.Header().Set("Proxy-Go-Cache", "MISS")
|
||||
w.Header().Del("Content-Security-Policy")
|
||||
|
||||
// 写入响应
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
n, err := w.Write(body)
|
||||
if err != nil && !isConnectionClosed(err) {
|
||||
log.Printf("Error writing response: %v", err)
|
||||
}
|
||||
|
||||
// 记录访问日志
|
||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(n), utils.GetClientIP(r), r)
|
||||
// 记录统计信息
|
||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), written, utils.GetClientIP(r), r)
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
|
@ -94,22 +94,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Error reading response", http.StatusInternalServerError)
|
||||
log.Printf("Error reading response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是GET请求且响应成功,尝试缓存
|
||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil {
|
||||
cacheKey := fixedPathCache.GenerateCacheKey(r)
|
||||
if _, err := fixedPathCache.Put(cacheKey, resp, body); err != nil {
|
||||
log.Printf("[Cache] Failed to cache %s: %v", targetURL, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 复制响应头
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
@ -121,14 +105,30 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
||||
// 设置响应状态码
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 写入响应体
|
||||
written, err := w.Write(body)
|
||||
var written int64
|
||||
// 如果是GET请求且响应成功,使用TeeReader同时写入缓存
|
||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil {
|
||||
cacheKey := fixedPathCache.GenerateCacheKey(r)
|
||||
if cacheFile, err := fixedPathCache.CreateTemp(cacheKey, resp); err == nil {
|
||||
defer cacheFile.Close()
|
||||
teeReader := io.TeeReader(resp.Body, cacheFile)
|
||||
written, err = io.Copy(w, teeReader)
|
||||
if err == nil {
|
||||
fixedPathCache.Commit(cacheKey, cacheFile.Name(), resp, written)
|
||||
}
|
||||
} else {
|
||||
written, err = io.Copy(w, resp.Body)
|
||||
}
|
||||
} else {
|
||||
written, err = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
if err != nil && !isConnectionClosed(err) {
|
||||
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
||||
}
|
||||
|
||||
// 记录统计信息
|
||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r)
|
||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r)
|
||||
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user