mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 00:21:56 +08:00
- Remove legacy static files, templates, and JavaScript - Update main.go to serve SPA-style web application - Modify admin route handling to support client-side routing - Simplify configuration and metrics API endpoints - Remove server-side template rendering in favor of static file serving - Update Dockerfile and GitHub Actions to build web frontend
573 lines
15 KiB
Go
573 lines
15 KiB
Go
package handler
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"proxy-go/internal/config"
|
||
"proxy-go/internal/metrics"
|
||
"proxy-go/internal/utils"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"golang.org/x/net/http2"
|
||
"golang.org/x/time/rate"
|
||
)
|
||
|
||
const (
|
||
smallBufferSize = 4 * 1024 // 4KB
|
||
mediumBufferSize = 32 * 1024 // 32KB
|
||
largeBufferSize = 64 * 1024 // 64KB
|
||
|
||
// 超时时间常量
|
||
clientConnTimeout = 5 * time.Second // 客户端连接超时
|
||
proxyRespTimeout = 30 * time.Second // 代理响应超时
|
||
backendServTimeout = 20 * time.Second // 后端服务超时
|
||
idleConnTimeout = 120 * time.Second // 空闲连接超时
|
||
tlsHandshakeTimeout = 10 * time.Second // TLS握手超时
|
||
|
||
// 限流相关常量
|
||
globalRateLimit = 1000 // 全局每秒请求数限制
|
||
globalBurstLimit = 200 // 全局突发请求数限制
|
||
perHostRateLimit = 100 // 每个host每秒请求数限制
|
||
perHostBurstLimit = 50 // 每个host突发请求数限制
|
||
perIPRateLimit = 20 // 每个IP每秒请求数限制
|
||
perIPBurstLimit = 10 // 每个IP突发请求数限制
|
||
cleanupInterval = 10 * time.Minute // 清理过期限流器的间隔
|
||
)
|
||
|
||
// 定义不同大小的缓冲池
|
||
var (
|
||
smallBufferPool = sync.Pool{
|
||
New: func() interface{} {
|
||
return bytes.NewBuffer(make([]byte, smallBufferSize))
|
||
},
|
||
}
|
||
|
||
mediumBufferPool = sync.Pool{
|
||
New: func() interface{} {
|
||
return bytes.NewBuffer(make([]byte, mediumBufferSize))
|
||
},
|
||
}
|
||
|
||
largeBufferPool = sync.Pool{
|
||
New: func() interface{} {
|
||
return bytes.NewBuffer(make([]byte, largeBufferSize))
|
||
},
|
||
}
|
||
|
||
// 用于大文件传输的字节切片池
|
||
byteSlicePool = sync.Pool{
|
||
New: func() interface{} {
|
||
b := make([]byte, largeBufferSize)
|
||
return &b
|
||
},
|
||
}
|
||
)
|
||
|
||
// getBuffer 根据大小选择合适的缓冲池
|
||
func getBuffer(size int64) (*bytes.Buffer, func()) {
|
||
var buf *bytes.Buffer
|
||
var pool *sync.Pool
|
||
|
||
switch {
|
||
case size <= smallBufferSize:
|
||
pool = &smallBufferPool
|
||
case size <= mediumBufferSize:
|
||
pool = &mediumBufferPool
|
||
default:
|
||
pool = &largeBufferPool
|
||
}
|
||
|
||
buf = pool.Get().(*bytes.Buffer)
|
||
buf.Reset() // 重置缓冲区
|
||
|
||
return buf, func() {
|
||
if buf != nil {
|
||
pool.Put(buf)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加 hop-by-hop 头部映射
|
||
var hopHeadersMap = make(map[string]bool)
|
||
|
||
func init() {
|
||
headers := []string{
|
||
"Connection",
|
||
"Keep-Alive",
|
||
"Proxy-Authenticate",
|
||
"Proxy-Authorization",
|
||
"Proxy-Connection",
|
||
"Te",
|
||
"Trailer",
|
||
"Transfer-Encoding",
|
||
"Upgrade",
|
||
}
|
||
for _, h := range headers {
|
||
hopHeadersMap[h] = true
|
||
}
|
||
}
|
||
|
||
// ErrorHandler 定义错误处理函数类型
|
||
type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error)
|
||
|
||
// RateLimiter 定义限流器接口
|
||
type RateLimiter interface {
|
||
Allow() bool
|
||
Clean(now time.Time)
|
||
}
|
||
|
||
// 限流管理器
|
||
type rateLimitManager struct {
|
||
globalLimiter *rate.Limiter
|
||
hostLimiters *sync.Map // host -> *rate.Limiter
|
||
ipLimiters *sync.Map // IP -> *rate.Limiter
|
||
lastCleanup time.Time
|
||
}
|
||
|
||
// 创建新的限流管理器
|
||
func newRateLimitManager() *rateLimitManager {
|
||
manager := &rateLimitManager{
|
||
globalLimiter: rate.NewLimiter(rate.Limit(globalRateLimit), globalBurstLimit),
|
||
hostLimiters: &sync.Map{},
|
||
ipLimiters: &sync.Map{},
|
||
lastCleanup: time.Now(),
|
||
}
|
||
|
||
// 启动清理协程
|
||
go manager.cleanupLoop()
|
||
return manager
|
||
}
|
||
|
||
func (m *rateLimitManager) cleanupLoop() {
|
||
ticker := time.NewTicker(cleanupInterval)
|
||
for range ticker.C {
|
||
now := time.Now()
|
||
m.cleanup(now)
|
||
}
|
||
}
|
||
|
||
func (m *rateLimitManager) cleanup(now time.Time) {
|
||
m.hostLimiters.Range(func(key, value interface{}) bool {
|
||
if now.Sub(m.lastCleanup) > cleanupInterval {
|
||
m.hostLimiters.Delete(key)
|
||
}
|
||
return true
|
||
})
|
||
|
||
m.ipLimiters.Range(func(key, value interface{}) bool {
|
||
if now.Sub(m.lastCleanup) > cleanupInterval {
|
||
m.ipLimiters.Delete(key)
|
||
}
|
||
return true
|
||
})
|
||
|
||
m.lastCleanup = now
|
||
}
|
||
|
||
func (m *rateLimitManager) getHostLimiter(host string) *rate.Limiter {
|
||
if limiter, exists := m.hostLimiters.Load(host); exists {
|
||
return limiter.(*rate.Limiter)
|
||
}
|
||
|
||
limiter := rate.NewLimiter(rate.Limit(perHostRateLimit), perHostBurstLimit)
|
||
m.hostLimiters.Store(host, limiter)
|
||
return limiter
|
||
}
|
||
|
||
func (m *rateLimitManager) getIPLimiter(ip string) *rate.Limiter {
|
||
if limiter, exists := m.ipLimiters.Load(ip); exists {
|
||
return limiter.(*rate.Limiter)
|
||
}
|
||
|
||
limiter := rate.NewLimiter(rate.Limit(perIPRateLimit), perIPBurstLimit)
|
||
m.ipLimiters.Store(ip, limiter)
|
||
return limiter
|
||
}
|
||
|
||
// 检查是否允许请求
|
||
func (m *rateLimitManager) allowRequest(r *http.Request) error {
|
||
// 全局限流检查
|
||
if !m.globalLimiter.Allow() {
|
||
return fmt.Errorf("global rate limit exceeded")
|
||
}
|
||
|
||
// Host限流检查
|
||
host := r.Host
|
||
if host != "" {
|
||
if !m.getHostLimiter(host).Allow() {
|
||
return fmt.Errorf("host rate limit exceeded for %s", host)
|
||
}
|
||
}
|
||
|
||
// IP限流检查
|
||
ip := utils.GetClientIP(r)
|
||
if ip != "" {
|
||
if !m.getIPLimiter(ip).Allow() {
|
||
return fmt.Errorf("ip rate limit exceeded for %s", ip)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
type ProxyHandler struct {
|
||
pathMap map[string]config.PathConfig
|
||
client *http.Client
|
||
limiter *rate.Limiter
|
||
startTime time.Time
|
||
config *config.Config
|
||
auth *authManager
|
||
errorHandler ErrorHandler // 添加错误处理器
|
||
rateLimiter *rateLimitManager
|
||
}
|
||
|
||
// NewProxyHandler 创建新的代理处理器
|
||
func NewProxyHandler(cfg *config.Config) *ProxyHandler {
|
||
dialer := &net.Dialer{
|
||
Timeout: clientConnTimeout,
|
||
KeepAlive: 30 * time.Second,
|
||
}
|
||
|
||
transport := &http.Transport{
|
||
DialContext: dialer.DialContext,
|
||
MaxIdleConns: 300,
|
||
MaxIdleConnsPerHost: 50,
|
||
IdleConnTimeout: idleConnTimeout,
|
||
TLSHandshakeTimeout: tlsHandshakeTimeout,
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
MaxConnsPerHost: 100,
|
||
DisableKeepAlives: false,
|
||
DisableCompression: false,
|
||
ForceAttemptHTTP2: true,
|
||
WriteBufferSize: 64 * 1024,
|
||
ReadBufferSize: 64 * 1024,
|
||
ResponseHeaderTimeout: backendServTimeout,
|
||
MaxResponseHeaderBytes: 64 * 1024,
|
||
}
|
||
|
||
// 设置HTTP/2传输配置
|
||
http2Transport, err := http2.ConfigureTransports(transport)
|
||
if err == nil && http2Transport != nil {
|
||
http2Transport.ReadIdleTimeout = 10 * time.Second
|
||
http2Transport.PingTimeout = 5 * time.Second
|
||
http2Transport.AllowHTTP = false
|
||
http2Transport.MaxReadFrameSize = 32 * 1024
|
||
http2Transport.StrictMaxConcurrentStreams = true
|
||
}
|
||
|
||
handler := &ProxyHandler{
|
||
pathMap: cfg.MAP,
|
||
client: &http.Client{
|
||
Transport: transport,
|
||
Timeout: proxyRespTimeout,
|
||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||
if len(via) >= 10 {
|
||
return fmt.Errorf("stopped after 10 redirects")
|
||
}
|
||
return nil
|
||
},
|
||
},
|
||
limiter: rate.NewLimiter(rate.Limit(5000), 10000),
|
||
startTime: time.Now(),
|
||
config: cfg,
|
||
auth: newAuthManager(),
|
||
rateLimiter: newRateLimitManager(),
|
||
errorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||
log.Printf("[Error] %s %s -> %v", r.Method, r.URL.Path, err)
|
||
if strings.Contains(err.Error(), "rate limit exceeded") {
|
||
w.WriteHeader(http.StatusTooManyRequests)
|
||
w.Write([]byte("Too Many Requests"))
|
||
} else {
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
w.Write([]byte("Internal Server Error"))
|
||
}
|
||
},
|
||
}
|
||
|
||
// 注册配置更新回调
|
||
config.RegisterUpdateCallback(func(newCfg *config.Config) {
|
||
handler.pathMap = newCfg.MAP
|
||
handler.config = newCfg
|
||
log.Printf("[Config] 配置已更新并生效")
|
||
})
|
||
|
||
return handler
|
||
}
|
||
|
||
// SetErrorHandler 允许自定义错误处理函数
|
||
func (h *ProxyHandler) SetErrorHandler(handler ErrorHandler) {
|
||
if handler != nil {
|
||
h.errorHandler = handler
|
||
}
|
||
}
|
||
|
||
// copyResponse 使用零拷贝方式传输数据
|
||
func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) {
|
||
buf := byteSlicePool.Get().(*[]byte)
|
||
defer byteSlicePool.Put(buf)
|
||
|
||
written, err := io.CopyBuffer(dst, src, *buf)
|
||
if err == nil && flusher != nil {
|
||
flusher.Flush()
|
||
}
|
||
return written, err
|
||
}
|
||
|
||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
// 添加 panic 恢复
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
log.Printf("[Panic] %s %s -> %v", r.Method, r.URL.Path, err)
|
||
h.errorHandler(w, r, fmt.Errorf("panic: %v", err))
|
||
}
|
||
}()
|
||
|
||
collector := metrics.GetCollector()
|
||
collector.BeginRequest()
|
||
defer collector.EndRequest()
|
||
|
||
// 限流检查
|
||
if err := h.rateLimiter.allowRequest(r); err != nil {
|
||
h.errorHandler(w, r, err)
|
||
return
|
||
}
|
||
|
||
start := time.Now()
|
||
|
||
// 创建带超时的上下文
|
||
ctx, cancel := context.WithTimeout(r.Context(), proxyRespTimeout)
|
||
defer cancel()
|
||
r = r.WithContext(ctx)
|
||
|
||
// 处理根路径请求
|
||
if r.URL.Path == "/" {
|
||
w.WriteHeader(http.StatusOK)
|
||
fmt.Fprint(w, "Welcome to CZL proxy.")
|
||
log.Printf("[%s] %s %s -> %d (root path) [%v]",
|
||
utils.GetClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(start))
|
||
return
|
||
}
|
||
|
||
// 查找匹配的代理路径
|
||
var matchedPrefix string
|
||
var pathConfig config.PathConfig
|
||
for prefix, cfg := range h.pathMap {
|
||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||
matchedPrefix = prefix
|
||
pathConfig = cfg
|
||
break
|
||
}
|
||
}
|
||
|
||
// 如果没有匹配的路径,返回 404
|
||
if matchedPrefix == "" {
|
||
http.NotFound(w, r)
|
||
log.Printf("[%s] %s %s -> 404 (not found) [%v]",
|
||
utils.GetClientIP(r), r.Method, r.URL.Path, time.Since(start))
|
||
return
|
||
}
|
||
|
||
// 构建目标 URL
|
||
targetPath := strings.TrimPrefix(r.URL.Path, matchedPrefix)
|
||
|
||
// URL 解码,然后重新编码,确保特殊字符被正确处理
|
||
decodedPath, err := url.QueryUnescape(targetPath)
|
||
if err != nil {
|
||
h.errorHandler(w, r, fmt.Errorf("error decoding path: %v", err))
|
||
log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]",
|
||
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start))
|
||
return
|
||
}
|
||
|
||
// 使用统一的路由选择逻辑
|
||
targetBase := utils.GetTargetURL(h.client, r, pathConfig, decodedPath)
|
||
|
||
// 重新编码路径,保留 '/'
|
||
parts := strings.Split(decodedPath, "/")
|
||
for i, part := range parts {
|
||
parts[i] = url.PathEscape(part)
|
||
}
|
||
encodedPath := strings.Join(parts, "/")
|
||
targetURL := targetBase + encodedPath
|
||
|
||
// 解析目标 URL 以获取 host
|
||
parsedURL, err := url.Parse(targetURL)
|
||
if err != nil {
|
||
h.errorHandler(w, r, fmt.Errorf("error parsing URL: %v", err))
|
||
log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]",
|
||
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start))
|
||
return
|
||
}
|
||
|
||
// 创建新的请求时使用带超时的上下文
|
||
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL, r.Body)
|
||
if err != nil {
|
||
h.errorHandler(w, r, fmt.Errorf("error creating request: %v", err))
|
||
return
|
||
}
|
||
|
||
// 添加请求追踪标识
|
||
requestID := utils.GenerateRequestID()
|
||
proxyReq.Header.Set("X-Request-ID", requestID)
|
||
w.Header().Set("X-Request-ID", requestID)
|
||
|
||
// 复制并处理请求头
|
||
copyHeader(proxyReq.Header, r.Header)
|
||
|
||
// 特别处理图片请求
|
||
if utils.IsImageRequest(r.URL.Path) {
|
||
// 获取 Accept 头
|
||
accept := r.Header.Get("Accept")
|
||
|
||
// 根据 Accept 头设置合适的图片格式
|
||
if strings.Contains(accept, "image/avif") {
|
||
proxyReq.Header.Set("Accept", "image/avif")
|
||
} else if strings.Contains(accept, "image/webp") {
|
||
proxyReq.Header.Set("Accept", "image/webp")
|
||
}
|
||
|
||
// 设置 Cloudflare 特定的头部
|
||
proxyReq.Header.Set("CF-Image-Format", "auto") // 让 Cloudflare 根据 Accept 头自动选择格式
|
||
}
|
||
|
||
// 设置其他必要的头部
|
||
proxyReq.Host = parsedURL.Host
|
||
proxyReq.Header.Set("Host", parsedURL.Host)
|
||
proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r))
|
||
proxyReq.Header.Set("X-Forwarded-Host", r.Host)
|
||
proxyReq.Header.Set("X-Forwarded-Proto", r.URL.Scheme)
|
||
|
||
// 添加或更新 X-Forwarded-For
|
||
if clientIP := utils.GetClientIP(r); clientIP != "" {
|
||
if prior := proxyReq.Header.Get("X-Forwarded-For"); prior != "" {
|
||
proxyReq.Header.Set("X-Forwarded-For", prior+", "+clientIP)
|
||
} else {
|
||
proxyReq.Header.Set("X-Forwarded-For", clientIP)
|
||
}
|
||
}
|
||
|
||
// 处理 Cookie 安全属性
|
||
if r.TLS != nil && len(proxyReq.Cookies()) > 0 {
|
||
cookies := proxyReq.Cookies()
|
||
for _, cookie := range cookies {
|
||
if !cookie.Secure {
|
||
cookie.Secure = true
|
||
}
|
||
if !cookie.HttpOnly {
|
||
cookie.HttpOnly = true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 发送代理请求
|
||
resp, err := h.client.Do(proxyReq)
|
||
if err != nil {
|
||
if ctx.Err() == context.DeadlineExceeded {
|
||
h.errorHandler(w, r, fmt.Errorf("request timeout after %v", proxyRespTimeout))
|
||
log.Printf("[Timeout] %s %s -> timeout after %v",
|
||
r.Method, r.URL.Path, proxyRespTimeout)
|
||
} else {
|
||
h.errorHandler(w, r, fmt.Errorf("proxy error: %v", err))
|
||
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
|
||
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start))
|
||
}
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
copyHeader(w.Header(), resp.Header)
|
||
|
||
// 删除严格的 CSP
|
||
w.Header().Del("Content-Security-Policy")
|
||
|
||
// 根据响应大小选择不同的处理策略
|
||
contentLength := resp.ContentLength
|
||
if contentLength > 0 && contentLength < 1<<20 { // 1MB 以下的小响应
|
||
// 获取合适大小的缓冲区
|
||
buf, putBuffer := getBuffer(contentLength)
|
||
defer putBuffer()
|
||
|
||
// 使用缓冲区读取响应
|
||
_, err := io.Copy(buf, resp.Body)
|
||
if err != nil {
|
||
h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err))
|
||
return
|
||
}
|
||
|
||
// 设置响应状态码并一次性写入响应
|
||
w.WriteHeader(resp.StatusCode)
|
||
written, err := w.Write(buf.Bytes())
|
||
if err != nil {
|
||
if !isConnectionClosed(err) {
|
||
log.Printf("Error writing response: %v", err)
|
||
}
|
||
}
|
||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(written), utils.GetClientIP(r), r)
|
||
} else {
|
||
// 大响应使用零拷贝传输
|
||
w.WriteHeader(resp.StatusCode)
|
||
var bytesCopied int64
|
||
var err error
|
||
|
||
if f, ok := w.(http.Flusher); ok {
|
||
bytesCopied, err = copyResponse(w, resp.Body, f)
|
||
} else {
|
||
bytesCopied, err = copyResponse(w, resp.Body, nil)
|
||
}
|
||
|
||
if err != nil && !isConnectionClosed(err) {
|
||
log.Printf("Error copying response: %v", err)
|
||
}
|
||
|
||
// 记录访问日志
|
||
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %-50s -> %s",
|
||
r.Method, // HTTP方法,左对齐占6位
|
||
resp.StatusCode, // 状态码,占3位
|
||
time.Since(start), // 处理时间,占12位
|
||
utils.GetClientIP(r), // IP地址,占15位
|
||
utils.FormatBytes(bytesCopied), // 传输大小,占10位
|
||
utils.GetRequestSource(r), // 请求来源
|
||
r.URL.Path, // 请求路径,左对齐占50位
|
||
targetURL, // 目标URL
|
||
)
|
||
|
||
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), bytesCopied, utils.GetClientIP(r), r)
|
||
}
|
||
}
|
||
|
||
func copyHeader(dst, src http.Header) {
|
||
// 处理 Connection 头部指定的其他 hop-by-hop 头部
|
||
if connection := src.Get("Connection"); connection != "" {
|
||
for _, h := range strings.Split(connection, ",") {
|
||
hopHeadersMap[strings.TrimSpace(h)] = true
|
||
}
|
||
}
|
||
|
||
// 使用 map 快速查找,跳过 hop-by-hop 头部
|
||
for k, vv := range src {
|
||
if !hopHeadersMap[k] {
|
||
for _, v := range vv {
|
||
dst.Add(k, v)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加辅助函数判断是否是连接关闭错误
|
||
func isConnectionClosed(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
return strings.Contains(err.Error(), "broken pipe") ||
|
||
strings.Contains(err.Error(), "connection reset by peer") ||
|
||
strings.Contains(err.Error(), "protocol wrong type for socket")
|
||
}
|