diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index 4a78310..9b148d1 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -28,23 +28,20 @@ const ( ) // 添加 hop-by-hop 头部映射 -var hopHeadersMap = make(map[string]bool) +var hopHeadersBase = map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "Proxy-Connection": true, + "Te": true, + "Trailer": true, + "Transfer-Encoding": true, + "Upgrade": true, +} func init() { - headers := []string{ - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Proxy-Connection", - "Te", - "Trailer", - "Transfer-Encoding", - "Upgrade", - } - for _, h := range headers { - hopHeadersMap[h] = true - } + // 移除旧的初始化代码,因为我们直接在 map 字面量中定义了所有值 } // ErrorHandler 定义错误处理函数类型 @@ -337,16 +334,22 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func copyHeader(dst, src http.Header) { + // 创建一个新的局部 map,复制基础 hop headers + hopHeaders := make(map[string]bool, len(hopHeadersBase)) + for k, v := range hopHeadersBase { + hopHeaders[k] = v + } + // 处理 Connection 头部指定的其他 hop-by-hop 头部 if connection := src.Get("Connection"); connection != "" { for _, h := range strings.Split(connection, ",") { - hopHeadersMap[strings.TrimSpace(h)] = true + hopHeaders[strings.TrimSpace(h)] = true } } - // 使用 map 快速查找,跳过 hop-by-hop 头部 + // 使用局部 map 快速查找,跳过 hop-by-hop 头部 for k, vv := range src { - if !hopHeadersMap[k] { + if !hopHeaders[k] { for _, v := range vv { dst.Add(k, v) } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 221da71..55a7181 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -187,88 +187,85 @@ func GetTargetURL(client *http.Client, r *http.Request, pathConfig config.PathCo // 默认使用默认目标 targetBase := pathConfig.DefaultTarget - // 如果没有设置最小阈值,使用默认值 500KB - minThreshold := pathConfig.SizeThreshold - if minThreshold <= 0 { - minThreshold = 500 * 1024 - } - - // 如果没有设置最大阈值,使用默认值 10MB - maxThreshold := pathConfig.MaxSize - if maxThreshold <= 0 { - maxThreshold = 10 * 1024 * 1024 - } - - // 检查文件扩展名 + // 如果配置了扩展名映射 if pathConfig.ExtensionMap != nil { ext := strings.ToLower(filepath.Ext(path)) if ext != "" { ext = ext[1:] // 移除开头的点 - // 先检查是否在扩展名映射中 + // 检查是否在扩展名映射中 if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists { - // 使用 channel 来并发获取文件大小和检查可访问性 - type result struct { - size int64 - accessible bool - err error - } - defaultChan := make(chan result, 1) - altChan := make(chan result, 1) - - // 并发检查默认源和备用源 - go func() { - size, err := GetFileSize(client, targetBase+path) - defaultChan <- result{size: size, err: err} - }() - go func() { - accessible := isTargetAccessible(client, altTarget+path) - altChan <- result{accessible: accessible} - }() - - // 获取默认源结果 - defaultResult := <-defaultChan - if defaultResult.err != nil { - log.Printf("[FileSize] Failed to get size from default source for %s: %v", path, defaultResult.err) + // 检查文件大小 + contentLength, err := GetFileSize(client, targetBase+path) + if err != nil { + log.Printf("[Route] %s -> %s (error getting size: %v)", path, targetBase, err) return targetBase } - contentLength := defaultResult.size - log.Printf("[FileSize] Path: %s, Size: %s (from default source)", - path, FormatBytes(contentLength)) - // 检查文件大小是否在阈值范围内 + // 如果没有设置最小阈值,使用默认值 500KB + minThreshold := pathConfig.SizeThreshold + if minThreshold <= 0 { + minThreshold = 500 * 1024 + } + + // 如果没有设置最大阈值,使用默认值 10MB + maxThreshold := pathConfig.MaxSize + if maxThreshold <= 0 { + maxThreshold = 10 * 1024 * 1024 + } + if contentLength > minThreshold && contentLength <= maxThreshold { - // 获取备用源检查结果 - altResult := <-altChan - if altResult.accessible { - log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)", - path, altTarget, FormatBytes(contentLength), - FormatBytes(minThreshold), FormatBytes(maxThreshold)) - return altTarget - } else { + // 创建一个带超时的 context + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 使用 channel 来接收备用源检查结果 + altChan := make(chan struct { + accessible bool + err error + }, 1) + + // 在 goroutine 中检查备用源可访问性 + go func() { + accessible := isTargetAccessible(client, altTarget+path) + select { + case altChan <- struct { + accessible bool + err error + }{accessible: accessible}: + case <-ctx.Done(): + // context 已取消,不需要发送结果 + } + }() + + // 等待结果或超时 + select { + case result := <-altChan: + if result.accessible { + log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)", + path, altTarget, FormatBytes(contentLength), + FormatBytes(minThreshold), FormatBytes(maxThreshold)) + return altTarget + } log.Printf("[Route] %s -> %s (fallback: alternative target not accessible)", path, targetBase) + case <-ctx.Done(): + log.Printf("[Route] %s -> %s (fallback: alternative target check timeout)", + path, targetBase) } } else if contentLength <= minThreshold { - // 如果文件大小不合适,直接丢弃备用源检查结果 - go func() { <-altChan }() log.Printf("[Route] %s -> %s (size: %s <= %s)", path, targetBase, FormatBytes(contentLength), FormatBytes(minThreshold)) } else { - // 如果文件大小不合适,直接丢弃备用源检查结果 - go func() { <-altChan }() log.Printf("[Route] %s -> %s (size: %s > %s)", path, targetBase, FormatBytes(contentLength), FormatBytes(maxThreshold)) } } else { - // 记录没有匹配扩展名映射的情况 log.Printf("[Route] %s -> %s (no extension mapping)", path, targetBase) } } else { - // 记录没有扩展名的情况 log.Printf("[Route] %s -> %s (no extension)", path, targetBase) } } else { - // 记录没有扩展名映射配置的情况 log.Printf("[Route] %s -> %s (no extension map)", path, targetBase) }