mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 16:41:54 +08:00
refactor(proxy): Optimize hop-by-hop header handling and routing logic
- Simplify hop-by-hop headers initialization using map literal - Create a local copy of hop headers to improve header filtering - Enhance routing logic with context-based timeout for alternative target checks - Improve error handling and logging in file size and routing detection - Reduce unnecessary goroutine complexity in target URL selection
This commit is contained in:
parent
ec07ae094e
commit
929d13157d
@ -28,23 +28,20 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// 添加 hop-by-hop 头部映射
|
// 添加 hop-by-hop 头部映射
|
||||||
var hopHeadersMap = make(map[string]bool)
|
var hopHeadersBase = map[string]bool{
|
||||||
|
"Connection": true,
|
||||||
|
"Keep-Alive": true,
|
||||||
|
"Proxy-Authenticate": true,
|
||||||
|
"Proxy-Authorization": true,
|
||||||
|
"Proxy-Connection": true,
|
||||||
|
"Te": true,
|
||||||
|
"Trailer": true,
|
||||||
|
"Transfer-Encoding": true,
|
||||||
|
"Upgrade": true,
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
headers := []string{
|
// 移除旧的初始化代码,因为我们直接在 map 字面量中定义了所有值
|
||||||
"Connection",
|
|
||||||
"Keep-Alive",
|
|
||||||
"Proxy-Authenticate",
|
|
||||||
"Proxy-Authorization",
|
|
||||||
"Proxy-Connection",
|
|
||||||
"Te",
|
|
||||||
"Trailer",
|
|
||||||
"Transfer-Encoding",
|
|
||||||
"Upgrade",
|
|
||||||
}
|
|
||||||
for _, h := range headers {
|
|
||||||
hopHeadersMap[h] = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorHandler 定义错误处理函数类型
|
// ErrorHandler 定义错误处理函数类型
|
||||||
@ -337,16 +334,22 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func copyHeader(dst, src http.Header) {
|
func copyHeader(dst, src http.Header) {
|
||||||
|
// 创建一个新的局部 map,复制基础 hop headers
|
||||||
|
hopHeaders := make(map[string]bool, len(hopHeadersBase))
|
||||||
|
for k, v := range hopHeadersBase {
|
||||||
|
hopHeaders[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
// 处理 Connection 头部指定的其他 hop-by-hop 头部
|
// 处理 Connection 头部指定的其他 hop-by-hop 头部
|
||||||
if connection := src.Get("Connection"); connection != "" {
|
if connection := src.Get("Connection"); connection != "" {
|
||||||
for _, h := range strings.Split(connection, ",") {
|
for _, h := range strings.Split(connection, ",") {
|
||||||
hopHeadersMap[strings.TrimSpace(h)] = true
|
hopHeaders[strings.TrimSpace(h)] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 map 快速查找,跳过 hop-by-hop 头部
|
// 使用局部 map 快速查找,跳过 hop-by-hop 头部
|
||||||
for k, vv := range src {
|
for k, vv := range src {
|
||||||
if !hopHeadersMap[k] {
|
if !hopHeaders[k] {
|
||||||
for _, v := range vv {
|
for _, v := range vv {
|
||||||
dst.Add(k, v)
|
dst.Add(k, v)
|
||||||
}
|
}
|
||||||
|
@ -187,88 +187,85 @@ func GetTargetURL(client *http.Client, r *http.Request, pathConfig config.PathCo
|
|||||||
// 默认使用默认目标
|
// 默认使用默认目标
|
||||||
targetBase := pathConfig.DefaultTarget
|
targetBase := pathConfig.DefaultTarget
|
||||||
|
|
||||||
// 如果没有设置最小阈值,使用默认值 500KB
|
// 如果配置了扩展名映射
|
||||||
minThreshold := pathConfig.SizeThreshold
|
|
||||||
if minThreshold <= 0 {
|
|
||||||
minThreshold = 500 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果没有设置最大阈值,使用默认值 10MB
|
|
||||||
maxThreshold := pathConfig.MaxSize
|
|
||||||
if maxThreshold <= 0 {
|
|
||||||
maxThreshold = 10 * 1024 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查文件扩展名
|
|
||||||
if pathConfig.ExtensionMap != nil {
|
if pathConfig.ExtensionMap != nil {
|
||||||
ext := strings.ToLower(filepath.Ext(path))
|
ext := strings.ToLower(filepath.Ext(path))
|
||||||
if ext != "" {
|
if ext != "" {
|
||||||
ext = ext[1:] // 移除开头的点
|
ext = ext[1:] // 移除开头的点
|
||||||
// 先检查是否在扩展名映射中
|
// 检查是否在扩展名映射中
|
||||||
if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists {
|
if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists {
|
||||||
// 使用 channel 来并发获取文件大小和检查可访问性
|
// 检查文件大小
|
||||||
type result struct {
|
contentLength, err := GetFileSize(client, targetBase+path)
|
||||||
size int64
|
if err != nil {
|
||||||
accessible bool
|
log.Printf("[Route] %s -> %s (error getting size: %v)", path, targetBase, err)
|
||||||
err error
|
|
||||||
}
|
|
||||||
defaultChan := make(chan result, 1)
|
|
||||||
altChan := make(chan result, 1)
|
|
||||||
|
|
||||||
// 并发检查默认源和备用源
|
|
||||||
go func() {
|
|
||||||
size, err := GetFileSize(client, targetBase+path)
|
|
||||||
defaultChan <- result{size: size, err: err}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
accessible := isTargetAccessible(client, altTarget+path)
|
|
||||||
altChan <- result{accessible: accessible}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 获取默认源结果
|
|
||||||
defaultResult := <-defaultChan
|
|
||||||
if defaultResult.err != nil {
|
|
||||||
log.Printf("[FileSize] Failed to get size from default source for %s: %v", path, defaultResult.err)
|
|
||||||
return targetBase
|
return targetBase
|
||||||
}
|
}
|
||||||
contentLength := defaultResult.size
|
|
||||||
log.Printf("[FileSize] Path: %s, Size: %s (from default source)",
|
|
||||||
path, FormatBytes(contentLength))
|
|
||||||
|
|
||||||
// 检查文件大小是否在阈值范围内
|
// 如果没有设置最小阈值,使用默认值 500KB
|
||||||
|
minThreshold := pathConfig.SizeThreshold
|
||||||
|
if minThreshold <= 0 {
|
||||||
|
minThreshold = 500 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有设置最大阈值,使用默认值 10MB
|
||||||
|
maxThreshold := pathConfig.MaxSize
|
||||||
|
if maxThreshold <= 0 {
|
||||||
|
maxThreshold = 10 * 1024 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
if contentLength > minThreshold && contentLength <= maxThreshold {
|
if contentLength > minThreshold && contentLength <= maxThreshold {
|
||||||
// 获取备用源检查结果
|
// 创建一个带超时的 context
|
||||||
altResult := <-altChan
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
if altResult.accessible {
|
defer cancel()
|
||||||
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
|
|
||||||
path, altTarget, FormatBytes(contentLength),
|
// 使用 channel 来接收备用源检查结果
|
||||||
FormatBytes(minThreshold), FormatBytes(maxThreshold))
|
altChan := make(chan struct {
|
||||||
return altTarget
|
accessible bool
|
||||||
} else {
|
err error
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
// 在 goroutine 中检查备用源可访问性
|
||||||
|
go func() {
|
||||||
|
accessible := isTargetAccessible(client, altTarget+path)
|
||||||
|
select {
|
||||||
|
case altChan <- struct {
|
||||||
|
accessible bool
|
||||||
|
err error
|
||||||
|
}{accessible: accessible}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
// context 已取消,不需要发送结果
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等待结果或超时
|
||||||
|
select {
|
||||||
|
case result := <-altChan:
|
||||||
|
if result.accessible {
|
||||||
|
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
|
||||||
|
path, altTarget, FormatBytes(contentLength),
|
||||||
|
FormatBytes(minThreshold), FormatBytes(maxThreshold))
|
||||||
|
return altTarget
|
||||||
|
}
|
||||||
log.Printf("[Route] %s -> %s (fallback: alternative target not accessible)",
|
log.Printf("[Route] %s -> %s (fallback: alternative target not accessible)",
|
||||||
path, targetBase)
|
path, targetBase)
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("[Route] %s -> %s (fallback: alternative target check timeout)",
|
||||||
|
path, targetBase)
|
||||||
}
|
}
|
||||||
} else if contentLength <= minThreshold {
|
} else if contentLength <= minThreshold {
|
||||||
// 如果文件大小不合适,直接丢弃备用源检查结果
|
|
||||||
go func() { <-altChan }()
|
|
||||||
log.Printf("[Route] %s -> %s (size: %s <= %s)",
|
log.Printf("[Route] %s -> %s (size: %s <= %s)",
|
||||||
path, targetBase, FormatBytes(contentLength), FormatBytes(minThreshold))
|
path, targetBase, FormatBytes(contentLength), FormatBytes(minThreshold))
|
||||||
} else {
|
} else {
|
||||||
// 如果文件大小不合适,直接丢弃备用源检查结果
|
|
||||||
go func() { <-altChan }()
|
|
||||||
log.Printf("[Route] %s -> %s (size: %s > %s)",
|
log.Printf("[Route] %s -> %s (size: %s > %s)",
|
||||||
path, targetBase, FormatBytes(contentLength), FormatBytes(maxThreshold))
|
path, targetBase, FormatBytes(contentLength), FormatBytes(maxThreshold))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 记录没有匹配扩展名映射的情况
|
|
||||||
log.Printf("[Route] %s -> %s (no extension mapping)", path, targetBase)
|
log.Printf("[Route] %s -> %s (no extension mapping)", path, targetBase)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 记录没有扩展名的情况
|
|
||||||
log.Printf("[Route] %s -> %s (no extension)", path, targetBase)
|
log.Printf("[Route] %s -> %s (no extension)", path, targetBase)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 记录没有扩展名映射配置的情况
|
|
||||||
log.Printf("[Route] %s -> %s (no extension map)", path, targetBase)
|
log.Printf("[Route] %s -> %s (no extension map)", path, targetBase)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user