refactor(proxy): Optimize hop-by-hop header handling and routing logic

- Simplify hop-by-hop headers initialization using map literal
- Create a local copy of hop headers to improve header filtering
- Enhance routing logic with context-based timeout for alternative target checks
- Improve error handling and logging in file size and routing detection
- Reduce unnecessary goroutine complexity in target URL selection
This commit is contained in:
wood chen 2025-02-18 13:34:50 +08:00
parent ec07ae094e
commit 929d13157d
2 changed files with 74 additions and 74 deletions

View File

@ -28,23 +28,20 @@ const (
)
// 添加 hop-by-hop 头部映射
var hopHeadersMap = make(map[string]bool)
var hopHeadersBase = map[string]bool{
"Connection": true,
"Keep-Alive": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"Proxy-Connection": true,
"Te": true,
"Trailer": true,
"Transfer-Encoding": true,
"Upgrade": true,
}
func init() {
headers := []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Proxy-Connection",
"Te",
"Trailer",
"Transfer-Encoding",
"Upgrade",
}
for _, h := range headers {
hopHeadersMap[h] = true
}
// 移除旧的初始化代码,因为我们直接在 map 字面量中定义了所有值
}
// ErrorHandler 定义错误处理函数类型
@ -337,16 +334,22 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func copyHeader(dst, src http.Header) {
// 创建一个新的局部 map复制基础 hop headers
hopHeaders := make(map[string]bool, len(hopHeadersBase))
for k, v := range hopHeadersBase {
hopHeaders[k] = v
}
// 处理 Connection 头部指定的其他 hop-by-hop 头部
if connection := src.Get("Connection"); connection != "" {
for _, h := range strings.Split(connection, ",") {
hopHeadersMap[strings.TrimSpace(h)] = true
hopHeaders[strings.TrimSpace(h)] = true
}
}
// 使用 map 快速查找,跳过 hop-by-hop 头部
// 使用局部 map 快速查找,跳过 hop-by-hop 头部
for k, vv := range src {
if !hopHeadersMap[k] {
if !hopHeaders[k] {
for _, v := range vv {
dst.Add(k, v)
}

View File

@ -187,88 +187,85 @@ func GetTargetURL(client *http.Client, r *http.Request, pathConfig config.PathCo
// 默认使用默认目标
targetBase := pathConfig.DefaultTarget
// 如果没有设置最小阈值,使用默认值 500KB
minThreshold := pathConfig.SizeThreshold
if minThreshold <= 0 {
minThreshold = 500 * 1024
}
// 如果没有设置最大阈值,使用默认值 10MB
maxThreshold := pathConfig.MaxSize
if maxThreshold <= 0 {
maxThreshold = 10 * 1024 * 1024
}
// 检查文件扩展名
// 如果配置了扩展名映射
if pathConfig.ExtensionMap != nil {
ext := strings.ToLower(filepath.Ext(path))
if ext != "" {
ext = ext[1:] // 移除开头的点
// 检查是否在扩展名映射中
// 检查是否在扩展名映射中
if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists {
// 使用 channel 来并发获取文件大小和检查可访问性
type result struct {
size int64
accessible bool
err error
}
defaultChan := make(chan result, 1)
altChan := make(chan result, 1)
// 并发检查默认源和备用源
go func() {
size, err := GetFileSize(client, targetBase+path)
defaultChan <- result{size: size, err: err}
}()
go func() {
accessible := isTargetAccessible(client, altTarget+path)
altChan <- result{accessible: accessible}
}()
// 获取默认源结果
defaultResult := <-defaultChan
if defaultResult.err != nil {
log.Printf("[FileSize] Failed to get size from default source for %s: %v", path, defaultResult.err)
// 检查文件大小
contentLength, err := GetFileSize(client, targetBase+path)
if err != nil {
log.Printf("[Route] %s -> %s (error getting size: %v)", path, targetBase, err)
return targetBase
}
contentLength := defaultResult.size
log.Printf("[FileSize] Path: %s, Size: %s (from default source)",
path, FormatBytes(contentLength))
// 检查文件大小是否在阈值范围内
// 如果没有设置最小阈值,使用默认值 500KB
minThreshold := pathConfig.SizeThreshold
if minThreshold <= 0 {
minThreshold = 500 * 1024
}
// 如果没有设置最大阈值,使用默认值 10MB
maxThreshold := pathConfig.MaxSize
if maxThreshold <= 0 {
maxThreshold = 10 * 1024 * 1024
}
if contentLength > minThreshold && contentLength <= maxThreshold {
// 获取备用源检查结果
altResult := <-altChan
if altResult.accessible {
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
path, altTarget, FormatBytes(contentLength),
FormatBytes(minThreshold), FormatBytes(maxThreshold))
return altTarget
} else {
// 创建一个带超时的 context
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 使用 channel 来接收备用源检查结果
altChan := make(chan struct {
accessible bool
err error
}, 1)
// 在 goroutine 中检查备用源可访问性
go func() {
accessible := isTargetAccessible(client, altTarget+path)
select {
case altChan <- struct {
accessible bool
err error
}{accessible: accessible}:
case <-ctx.Done():
// context 已取消,不需要发送结果
}
}()
// 等待结果或超时
select {
case result := <-altChan:
if result.accessible {
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
path, altTarget, FormatBytes(contentLength),
FormatBytes(minThreshold), FormatBytes(maxThreshold))
return altTarget
}
log.Printf("[Route] %s -> %s (fallback: alternative target not accessible)",
path, targetBase)
case <-ctx.Done():
log.Printf("[Route] %s -> %s (fallback: alternative target check timeout)",
path, targetBase)
}
} else if contentLength <= minThreshold {
// 如果文件大小不合适,直接丢弃备用源检查结果
go func() { <-altChan }()
log.Printf("[Route] %s -> %s (size: %s <= %s)",
path, targetBase, FormatBytes(contentLength), FormatBytes(minThreshold))
} else {
// 如果文件大小不合适,直接丢弃备用源检查结果
go func() { <-altChan }()
log.Printf("[Route] %s -> %s (size: %s > %s)",
path, targetBase, FormatBytes(contentLength), FormatBytes(maxThreshold))
}
} else {
// 记录没有匹配扩展名映射的情况
log.Printf("[Route] %s -> %s (no extension mapping)", path, targetBase)
}
} else {
// 记录没有扩展名的情况
log.Printf("[Route] %s -> %s (no extension)", path, targetBase)
}
} else {
// 记录没有扩展名映射配置的情况
log.Printf("[Route] %s -> %s (no extension map)", path, targetBase)
}