mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 16:41:54 +08:00
feat(cache): Implement streaming cache with TeeReader for efficient response handling
- Replace full response body reading with streaming cache mechanism - Add CreateTemp and Commit methods to CacheManager for incremental caching - Use io.TeeReader to simultaneously write response to client and cache file - Remove buffer pool and full body reading in proxy and mirror handlers - Improve memory efficiency and reduce latency for large responses - Update error handling and logging for cache-related operations
This commit is contained in:
parent
ffc64bb73a
commit
2267a27b37
55
internal/cache/manager.go
vendored
55
internal/cache/manager.go
vendored
@ -350,3 +350,58 @@ func (cm *CacheManager) ClearCache() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateTemp 创建临时缓存文件
|
||||||
|
func (cm *CacheManager) CreateTemp(key CacheKey, resp *http.Response) (*os.File, error) {
|
||||||
|
if !cm.enabled.Load() {
|
||||||
|
return nil, fmt.Errorf("cache is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建临时文件
|
||||||
|
tempFile, err := os.CreateTemp(cm.cacheDir, "temp-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tempFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit 提交缓存文件
|
||||||
|
func (cm *CacheManager) Commit(key CacheKey, tempPath string, resp *http.Response, size int64) error {
|
||||||
|
if !cm.enabled.Load() {
|
||||||
|
os.Remove(tempPath)
|
||||||
|
return fmt.Errorf("cache is disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成最终的缓存文件名
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(key.String()))
|
||||||
|
hashStr := hex.EncodeToString(h.Sum(nil))
|
||||||
|
ext := filepath.Ext(key.URL)
|
||||||
|
if ext == "" {
|
||||||
|
ext = ".bin"
|
||||||
|
}
|
||||||
|
filePath := filepath.Join(cm.cacheDir, hashStr+ext)
|
||||||
|
|
||||||
|
// 重命名临时文件
|
||||||
|
if err := os.Rename(tempPath, filePath); err != nil {
|
||||||
|
os.Remove(tempPath)
|
||||||
|
return fmt.Errorf("failed to rename temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建缓存项
|
||||||
|
item := &CacheItem{
|
||||||
|
FilePath: filePath,
|
||||||
|
ContentType: resp.Header.Get("Content-Type"),
|
||||||
|
Size: size,
|
||||||
|
LastAccess: time.Now(),
|
||||||
|
Hash: hashStr,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
AccessCount: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.items.Store(key, item)
|
||||||
|
cm.bytesSaved.Add(size)
|
||||||
|
log.Printf("[Cache] Cached %s (%s)", key.URL, formatBytes(size))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -144,22 +144,6 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 读取响应体
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Error reading response", http.StatusInternalServerError)
|
|
||||||
log.Printf("Error reading response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果是GET请求且响应成功,尝试缓存
|
|
||||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
|
||||||
cacheKey := h.Cache.GenerateCacheKey(r)
|
|
||||||
if _, err := h.Cache.Put(cacheKey, resp, body); err != nil {
|
|
||||||
log.Printf("[Cache] Failed to cache %s: %v", actualURL, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 复制响应头
|
// 复制响应头
|
||||||
copyHeader(w.Header(), resp.Header)
|
copyHeader(w.Header(), resp.Header)
|
||||||
w.Header().Set("Proxy-Go-Cache", "MISS")
|
w.Header().Set("Proxy-Go-Cache", "MISS")
|
||||||
@ -167,19 +151,38 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// 设置状态码
|
// 设置状态码
|
||||||
w.WriteHeader(resp.StatusCode)
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
// 写入响应体
|
var written int64
|
||||||
written, err := w.Write(body)
|
// 如果是GET请求且响应成功,使用TeeReader同时写入缓存
|
||||||
if err != nil {
|
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
||||||
|
cacheKey := h.Cache.GenerateCacheKey(r)
|
||||||
|
if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil {
|
||||||
|
defer cacheFile.Close()
|
||||||
|
teeReader := io.TeeReader(resp.Body, cacheFile)
|
||||||
|
written, err = io.Copy(w, teeReader)
|
||||||
|
if err == nil {
|
||||||
|
h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
written, err = io.Copy(w, resp.Body)
|
||||||
|
if err != nil && !isConnectionClosed(err) {
|
||||||
log.Printf("Error writing response: %v", err)
|
log.Printf("Error writing response: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
written, err = io.Copy(w, resp.Body)
|
||||||
|
if err != nil && !isConnectionClosed(err) {
|
||||||
|
log.Printf("Error writing response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 记录访问日志
|
// 记录访问日志
|
||||||
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s",
|
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s",
|
||||||
r.Method, resp.StatusCode, time.Since(startTime),
|
r.Method, resp.StatusCode, time.Since(startTime),
|
||||||
utils.GetClientIP(r), utils.FormatBytes(int64(written)),
|
utils.GetClientIP(r), utils.FormatBytes(written),
|
||||||
utils.GetRequestSource(r), actualURL)
|
utils.GetRequestSource(r), actualURL)
|
||||||
|
|
||||||
// 记录统计信息
|
// 记录统计信息
|
||||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r)
|
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -14,16 +13,12 @@ import (
|
|||||||
"proxy-go/internal/metrics"
|
"proxy-go/internal/metrics"
|
||||||
"proxy-go/internal/utils"
|
"proxy-go/internal/utils"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// 缓冲区大小
|
|
||||||
defaultBufferSize = 32 * 1024 // 32KB
|
|
||||||
|
|
||||||
// 超时时间常量
|
// 超时时间常量
|
||||||
clientConnTimeout = 10 * time.Second
|
clientConnTimeout = 10 * time.Second
|
||||||
proxyRespTimeout = 60 * time.Second
|
proxyRespTimeout = 60 * time.Second
|
||||||
@ -32,13 +27,6 @@ const (
|
|||||||
tlsHandshakeTimeout = 10 * time.Second
|
tlsHandshakeTimeout = 10 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// 统一的缓冲池
|
|
||||||
var bufferPool = sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
return bytes.NewBuffer(make([]byte, defaultBufferSize))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加 hop-by-hop 头部映射
|
// 添加 hop-by-hop 头部映射
|
||||||
var hopHeadersMap = make(map[string]bool)
|
var hopHeadersMap = make(map[string]bool)
|
||||||
|
|
||||||
@ -145,49 +133,6 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler {
|
|||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// copyResponse 使用缓冲方式传输数据
|
|
||||||
func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) {
|
|
||||||
buf := bufferPool.Get().(*bytes.Buffer)
|
|
||||||
defer bufferPool.Put(buf)
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
var written int64
|
|
||||||
for {
|
|
||||||
// 清空缓冲区
|
|
||||||
buf.Reset()
|
|
||||||
|
|
||||||
// 读取数据到缓冲区
|
|
||||||
_, er := io.CopyN(buf, src, defaultBufferSize)
|
|
||||||
if er != nil && er != io.EOF {
|
|
||||||
return written, er
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果有数据,写入目标
|
|
||||||
if buf.Len() > 0 {
|
|
||||||
nw, ew := dst.Write(buf.Bytes())
|
|
||||||
if ew != nil {
|
|
||||||
return written, ew
|
|
||||||
}
|
|
||||||
written += int64(nw)
|
|
||||||
|
|
||||||
// 定期刷新缓冲区
|
|
||||||
if flusher != nil && written%(1024*1024) == 0 { // 每1MB刷新一次
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if er == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 最后一次刷新
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
return written, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// 添加 panic 恢复
|
// 添加 panic 恢复
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -356,37 +301,41 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 读取响应体到缓冲区
|
// 复制响应头
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
_, err = copyResponse(buf, resp.Body, nil)
|
|
||||||
if err != nil {
|
|
||||||
h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
body := buf.Bytes()
|
|
||||||
|
|
||||||
// 如果是GET请求且响应成功,尝试缓存
|
|
||||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
|
||||||
cacheKey := h.Cache.GenerateCacheKey(r)
|
|
||||||
if _, err := h.Cache.Put(cacheKey, resp, body); err != nil {
|
|
||||||
log.Printf("[Cache] Failed to cache %s: %v", r.URL.Path, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
copyHeader(w.Header(), resp.Header)
|
copyHeader(w.Header(), resp.Header)
|
||||||
w.Header().Set("Proxy-Go-Cache", "MISS")
|
w.Header().Set("Proxy-Go-Cache", "MISS")
|
||||||
w.Header().Del("Content-Security-Policy")
|
|
||||||
|
|
||||||
// 写入响应
|
// 设置响应状态码
|
||||||
w.WriteHeader(resp.StatusCode)
|
w.WriteHeader(resp.StatusCode)
|
||||||
n, err := w.Write(body)
|
|
||||||
|
var written int64
|
||||||
|
// 如果是GET请求且响应成功,使用TeeReader同时写入缓存
|
||||||
|
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && h.Cache != nil {
|
||||||
|
cacheKey := h.Cache.GenerateCacheKey(r)
|
||||||
|
if cacheFile, err := h.Cache.CreateTemp(cacheKey, resp); err == nil {
|
||||||
|
defer cacheFile.Close()
|
||||||
|
teeReader := io.TeeReader(resp.Body, cacheFile)
|
||||||
|
written, err = io.Copy(w, teeReader)
|
||||||
|
if err == nil {
|
||||||
|
h.Cache.Commit(cacheKey, cacheFile.Name(), resp, written)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
written, err = io.Copy(w, resp.Body)
|
||||||
if err != nil && !isConnectionClosed(err) {
|
if err != nil && !isConnectionClosed(err) {
|
||||||
log.Printf("Error writing response: %v", err)
|
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
written, err = io.Copy(w, resp.Body)
|
||||||
|
if err != nil && !isConnectionClosed(err) {
|
||||||
|
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录访问日志
|
// 记录统计信息
|
||||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(n), utils.GetClientIP(r), r)
|
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), written, utils.GetClientIP(r), r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyHeader(dst, src http.Header) {
|
func copyHeader(dst, src http.Header) {
|
||||||
|
@ -94,22 +94,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 读取响应体
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Error reading response", http.StatusInternalServerError)
|
|
||||||
log.Printf("Error reading response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果是GET请求且响应成功,尝试缓存
|
|
||||||
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil {
|
|
||||||
cacheKey := fixedPathCache.GenerateCacheKey(r)
|
|
||||||
if _, err := fixedPathCache.Put(cacheKey, resp, body); err != nil {
|
|
||||||
log.Printf("[Cache] Failed to cache %s: %v", targetURL, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 复制响应头
|
// 复制响应头
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
@ -121,14 +105,30 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
// 设置响应状态码
|
// 设置响应状态码
|
||||||
w.WriteHeader(resp.StatusCode)
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
// 写入响应体
|
var written int64
|
||||||
written, err := w.Write(body)
|
// 如果是GET请求且响应成功,使用TeeReader同时写入缓存
|
||||||
|
if r.Method == http.MethodGet && resp.StatusCode == http.StatusOK && fixedPathCache != nil {
|
||||||
|
cacheKey := fixedPathCache.GenerateCacheKey(r)
|
||||||
|
if cacheFile, err := fixedPathCache.CreateTemp(cacheKey, resp); err == nil {
|
||||||
|
defer cacheFile.Close()
|
||||||
|
teeReader := io.TeeReader(resp.Body, cacheFile)
|
||||||
|
written, err = io.Copy(w, teeReader)
|
||||||
|
if err == nil {
|
||||||
|
fixedPathCache.Commit(cacheKey, cacheFile.Name(), resp, written)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
written, err = io.Copy(w, resp.Body)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
written, err = io.Copy(w, resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil && !isConnectionClosed(err) {
|
if err != nil && !isConnectionClosed(err) {
|
||||||
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
log.Printf("[%s] Error writing response: %v", utils.GetClientIP(r), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录统计信息
|
// 记录统计信息
|
||||||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), int64(written), utils.GetClientIP(r), r)
|
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user