mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 08:31:55 +08:00
refactor(proxy): Optimize hop-by-hop header handling and routing logic
- Simplify hop-by-hop headers initialization using map literal - Create a local copy of hop headers to improve header filtering - Enhance routing logic with context-based timeout for alternative target checks - Improve error handling and logging in file size and routing detection - Reduce unnecessary goroutine complexity in target URL selection
This commit is contained in:
parent
ec07ae094e
commit
929d13157d
@ -28,23 +28,20 @@ const (
|
||||
)
|
||||
|
||||
// 添加 hop-by-hop 头部映射
|
||||
var hopHeadersMap = make(map[string]bool)
|
||||
var hopHeadersBase = map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Proxy-Authenticate": true,
|
||||
"Proxy-Authorization": true,
|
||||
"Proxy-Connection": true,
|
||||
"Te": true,
|
||||
"Trailer": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
headers := []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Proxy-Connection",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
for _, h := range headers {
|
||||
hopHeadersMap[h] = true
|
||||
}
|
||||
// 移除旧的初始化代码,因为我们直接在 map 字面量中定义了所有值
|
||||
}
|
||||
|
||||
// ErrorHandler 定义错误处理函数类型
|
||||
@ -337,16 +334,22 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
// 创建一个新的局部 map,复制基础 hop headers
|
||||
hopHeaders := make(map[string]bool, len(hopHeadersBase))
|
||||
for k, v := range hopHeadersBase {
|
||||
hopHeaders[k] = v
|
||||
}
|
||||
|
||||
// 处理 Connection 头部指定的其他 hop-by-hop 头部
|
||||
if connection := src.Get("Connection"); connection != "" {
|
||||
for _, h := range strings.Split(connection, ",") {
|
||||
hopHeadersMap[strings.TrimSpace(h)] = true
|
||||
hopHeaders[strings.TrimSpace(h)] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 map 快速查找,跳过 hop-by-hop 头部
|
||||
// 使用局部 map 快速查找,跳过 hop-by-hop 头部
|
||||
for k, vv := range src {
|
||||
if !hopHeadersMap[k] {
|
||||
if !hopHeaders[k] {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
|
@ -187,6 +187,20 @@ func GetTargetURL(client *http.Client, r *http.Request, pathConfig config.PathCo
|
||||
// 默认使用默认目标
|
||||
targetBase := pathConfig.DefaultTarget
|
||||
|
||||
// 如果配置了扩展名映射
|
||||
if pathConfig.ExtensionMap != nil {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if ext != "" {
|
||||
ext = ext[1:] // 移除开头的点
|
||||
// 检查是否在扩展名映射中
|
||||
if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists {
|
||||
// 检查文件大小
|
||||
contentLength, err := GetFileSize(client, targetBase+path)
|
||||
if err != nil {
|
||||
log.Printf("[Route] %s -> %s (error getting size: %v)", path, targetBase, err)
|
||||
return targetBase
|
||||
}
|
||||
|
||||
// 如果没有设置最小阈值,使用默认值 500KB
|
||||
minThreshold := pathConfig.SizeThreshold
|
||||
if minThreshold <= 0 {
|
||||
@ -199,76 +213,59 @@ func GetTargetURL(client *http.Client, r *http.Request, pathConfig config.PathCo
|
||||
maxThreshold = 10 * 1024 * 1024
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
if pathConfig.ExtensionMap != nil {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if ext != "" {
|
||||
ext = ext[1:] // 移除开头的点
|
||||
// 先检查是否在扩展名映射中
|
||||
if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists {
|
||||
// 使用 channel 来并发获取文件大小和检查可访问性
|
||||
type result struct {
|
||||
size int64
|
||||
if contentLength > minThreshold && contentLength <= maxThreshold {
|
||||
// 创建一个带超时的 context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用 channel 来接收备用源检查结果
|
||||
altChan := make(chan struct {
|
||||
accessible bool
|
||||
err error
|
||||
}
|
||||
defaultChan := make(chan result, 1)
|
||||
altChan := make(chan result, 1)
|
||||
}, 1)
|
||||
|
||||
// 并发检查默认源和备用源
|
||||
go func() {
|
||||
size, err := GetFileSize(client, targetBase+path)
|
||||
defaultChan <- result{size: size, err: err}
|
||||
}()
|
||||
// 在 goroutine 中检查备用源可访问性
|
||||
go func() {
|
||||
accessible := isTargetAccessible(client, altTarget+path)
|
||||
altChan <- result{accessible: accessible}
|
||||
select {
|
||||
case altChan <- struct {
|
||||
accessible bool
|
||||
err error
|
||||
}{accessible: accessible}:
|
||||
case <-ctx.Done():
|
||||
// context 已取消,不需要发送结果
|
||||
}
|
||||
}()
|
||||
|
||||
// 获取默认源结果
|
||||
defaultResult := <-defaultChan
|
||||
if defaultResult.err != nil {
|
||||
log.Printf("[FileSize] Failed to get size from default source for %s: %v", path, defaultResult.err)
|
||||
return targetBase
|
||||
}
|
||||
contentLength := defaultResult.size
|
||||
log.Printf("[FileSize] Path: %s, Size: %s (from default source)",
|
||||
path, FormatBytes(contentLength))
|
||||
|
||||
// 检查文件大小是否在阈值范围内
|
||||
if contentLength > minThreshold && contentLength <= maxThreshold {
|
||||
// 获取备用源检查结果
|
||||
altResult := <-altChan
|
||||
if altResult.accessible {
|
||||
// 等待结果或超时
|
||||
select {
|
||||
case result := <-altChan:
|
||||
if result.accessible {
|
||||
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
|
||||
path, altTarget, FormatBytes(contentLength),
|
||||
FormatBytes(minThreshold), FormatBytes(maxThreshold))
|
||||
return altTarget
|
||||
} else {
|
||||
}
|
||||
log.Printf("[Route] %s -> %s (fallback: alternative target not accessible)",
|
||||
path, targetBase)
|
||||
case <-ctx.Done():
|
||||
log.Printf("[Route] %s -> %s (fallback: alternative target check timeout)",
|
||||
path, targetBase)
|
||||
}
|
||||
} else if contentLength <= minThreshold {
|
||||
// 如果文件大小不合适,直接丢弃备用源检查结果
|
||||
go func() { <-altChan }()
|
||||
log.Printf("[Route] %s -> %s (size: %s <= %s)",
|
||||
path, targetBase, FormatBytes(contentLength), FormatBytes(minThreshold))
|
||||
} else {
|
||||
// 如果文件大小不合适,直接丢弃备用源检查结果
|
||||
go func() { <-altChan }()
|
||||
log.Printf("[Route] %s -> %s (size: %s > %s)",
|
||||
path, targetBase, FormatBytes(contentLength), FormatBytes(maxThreshold))
|
||||
}
|
||||
} else {
|
||||
// 记录没有匹配扩展名映射的情况
|
||||
log.Printf("[Route] %s -> %s (no extension mapping)", path, targetBase)
|
||||
}
|
||||
} else {
|
||||
// 记录没有扩展名的情况
|
||||
log.Printf("[Route] %s -> %s (no extension)", path, targetBase)
|
||||
}
|
||||
} else {
|
||||
// 记录没有扩展名映射配置的情况
|
||||
log.Printf("[Route] %s -> %s (no extension map)", path, targetBase)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user