diff --git a/Dockerfile b/Dockerfile index b02fa35..ad3f4b6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,6 @@ RUN mkdir -p /app/data && \ chmod +x /app/proxy-go && \ apk add --no-cache ca-certificates tzdata -EXPOSE 80 +EXPOSE 3336 VOLUME ["/app/data"] ENTRYPOINT ["/app/proxy-go"] diff --git a/docker-compose.yml b/docker-compose.yml index a8de295..1efe3ad 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,7 @@ services: image: woodchen/proxy-go:latest container_name: proxy-go ports: - - "3334:80" + - "3336:3336" volumes: - ./data:/app/data environment: @@ -18,7 +18,7 @@ services: cpus: '0.25' memory: 128M healthcheck: - test: ["CMD", "wget", "-q", "--spider", "http://localhost:80/"] + test: ["CMD", "wget", "-q", "--spider", "http://localhost:3336/"] interval: 30s timeout: 3s retries: 3 \ No newline at end of file diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 0c751c0..853fa1f 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -3,6 +3,9 @@ package handler import ( "crypto/rand" "encoding/base64" + "encoding/json" + "net/http" + "strings" "sync" "time" ) @@ -60,3 +63,51 @@ func (am *authManager) cleanExpiredTokens() { }) } } + +// AuthMiddleware 认证中间件 +func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" || !strings.HasPrefix(auth, "Bearer ") { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + token := strings.TrimPrefix(auth, "Bearer ") + if !h.auth.validateToken(token) { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + next(w, r) + } +} + +// AuthHandler 处理认证请求 +func (h *ProxyHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + if req.Password != h.config.Metrics.Password { + http.Error(w, "Invalid password", http.StatusUnauthorized) + return + } + + token := h.auth.generateToken() + h.auth.addToken(token, time.Duration(h.config.Metrics.TokenExpiry)*time.Second) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token": token, + }) +} diff --git a/internal/handler/config.go b/internal/handler/config.go new file mode 100644 index 0000000..55a55d5 --- /dev/null +++ b/internal/handler/config.go @@ -0,0 +1,141 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "proxy-go/internal/config" +) + +// ConfigHandler 配置管理处理器 +type ConfigHandler struct { + config *config.Config +} + +// NewConfigHandler 创建新的配置管理处理器 +func NewConfigHandler(cfg *config.Config) *ConfigHandler { + return &ConfigHandler{ + config: cfg, + } +} + +// ServeHTTP 实现http.Handler接口 +func (h *ConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/metrics/config": + h.handleConfigPage(w, r) + case "/metrics/config/get": + h.handleGetConfig(w, r) + case "/metrics/config/save": + h.handleSaveConfig(w, r) + default: + http.NotFound(w, r) + } +} + +// handleConfigPage 处理配置页面请求 +func (h *ConfigHandler) handleConfigPage(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, "web/templates/config.html") +} + +// handleGetConfig 处理获取配置请求 +func (h *ConfigHandler) handleGetConfig(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + + // 读取当前配置文件 + configData, err := os.ReadFile("data/config.json") + if err != nil { + http.Error(w, fmt.Sprintf("读取配置文件失败: %v", err), http.StatusInternalServerError) + return + } + + w.Write(configData) +} + +// handleSaveConfig 处理保存配置请求 +func (h *ConfigHandler) handleSaveConfig(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "方法不允许", http.StatusMethodNotAllowed) + return + } + + // 解析新配置 + var newConfig config.Config + if err := json.NewDecoder(r.Body).Decode(&newConfig); err != nil { + http.Error(w, fmt.Sprintf("解析配置失败: %v", err), http.StatusBadRequest) + return + } + + // 验证新配置 + if err := h.validateConfig(&newConfig); err != nil { + http.Error(w, fmt.Sprintf("配置验证失败: %v", err), http.StatusBadRequest) + return + } + + // 将新配置格式化为JSON + configData, err := json.MarshalIndent(newConfig, "", " ") + if err != nil { + http.Error(w, fmt.Sprintf("格式化配置失败: %v", err), http.StatusInternalServerError) + return + } + + // 保存到临时文件 + tempFile := "data/config.json.tmp" + if err := os.WriteFile(tempFile, configData, 0644); err != nil { + http.Error(w, fmt.Sprintf("保存配置失败: %v", err), http.StatusInternalServerError) + return + } + + // 重命名临时文件为正式文件 + if err := os.Rename(tempFile, "data/config.json"); err != nil { + http.Error(w, fmt.Sprintf("更新配置文件失败: %v", err), http.StatusInternalServerError) + return + } + + // 更新运行时配置 + *h.config = newConfig + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "配置已更新并生效"}`)) +} + +// validateConfig 验证配置 +func (h *ConfigHandler) validateConfig(cfg *config.Config) error { + if cfg == nil { + return fmt.Errorf("配置不能为空") + } + + // 验证MAP配置 + if cfg.MAP == nil { + return fmt.Errorf("MAP配置不能为空") + } + + for path, pathConfig := range cfg.MAP { + if path == "" { + return fmt.Errorf("路径不能为空") + } + if pathConfig.DefaultTarget == "" { + return fmt.Errorf("路径 %s 的默认目标不能为空", path) + } + if _, err := url.Parse(pathConfig.DefaultTarget); err != nil { + return fmt.Errorf("路径 %s 的默认目标URL无效: %v", path, err) + } + } + + // 验证FixedPaths配置 + for _, fp := range cfg.FixedPaths { + if fp.Path == "" { + return fmt.Errorf("固定路径不能为空") + } + if fp.TargetURL == "" { + return fmt.Errorf("固定路径 %s 的目标URL不能为空", fp.Path) + } + if _, err := url.Parse(fp.TargetURL); err != nil { + return fmt.Errorf("固定路径 %s 的目标URL无效: %v", fp.Path, err) + } + } + + return nil +} diff --git a/internal/handler/metrics.go b/internal/handler/metrics.go index b38ec03..c929437 100644 --- a/internal/handler/metrics.go +++ b/internal/handler/metrics.go @@ -6,8 +6,8 @@ import ( "net/http" "proxy-go/internal/metrics" "proxy-go/internal/models" + "proxy-go/internal/utils" "runtime" - "strings" "time" ) @@ -43,7 +43,6 @@ func (h *ProxyHandler) MetricsHandler(w http.ResponseWriter, r *http.Request) { stats := collector.GetStats() if stats == nil { - // 返回默认值而不是错误 stats = map[string]interface{}{ "uptime": uptime.String(), "active_requests": int64(0), @@ -64,23 +63,27 @@ func (h *ProxyHandler) MetricsHandler(w http.ResponseWriter, r *http.Request) { } } - // 确保所有必要的字段都存在 + totalRequests := utils.SafeInt64(stats["total_requests"]) + totalErrors := utils.SafeInt64(stats["total_errors"]) + totalBytes := utils.SafeInt64(stats["total_bytes"]) + uptimeSeconds := uptime.Seconds() + metrics := Metrics{ Uptime: uptime.String(), - ActiveRequests: safeInt64(stats["active_requests"]), - TotalRequests: safeInt64(stats["total_requests"]), - TotalErrors: safeInt64(stats["total_errors"]), - ErrorRate: float64(safeInt64(stats["total_errors"])) / float64(max(safeInt64(stats["total_requests"]), 1)), - NumGoroutine: safeInt(stats["num_goroutine"]), - MemoryUsage: safeString(stats["memory_usage"]), - AverageResponseTime: safeString(stats["avg_response_time"]), - TotalBytes: safeInt64(stats["total_bytes"]), - BytesPerSecond: float64(safeInt64(stats["total_bytes"])) / metrics.Max(uptime.Seconds(), 1), - RequestsPerSecond: float64(safeInt64(stats["total_requests"])) / metrics.Max(uptime.Seconds(), 1), - StatusCodeStats: safeStatusCodeStats(stats["status_code_stats"]), - TopPaths: safePathMetrics(stats["top_paths"]), - RecentRequests: safeRequestLogs(stats["recent_requests"]), - TopReferers: safePathMetrics(stats["top_referers"]), + ActiveRequests: utils.SafeInt64(stats["active_requests"]), + TotalRequests: totalRequests, + TotalErrors: totalErrors, + ErrorRate: float64(totalErrors) / float64(utils.Max(totalRequests, 1)), + NumGoroutine: utils.SafeInt(stats["num_goroutine"]), + MemoryUsage: utils.SafeString(stats["memory_usage"], "0 B"), + AverageResponseTime: utils.SafeString(stats["avg_response_time"], "0 ms"), + TotalBytes: totalBytes, + BytesPerSecond: float64(totalBytes) / utils.MaxFloat64(uptimeSeconds, 1), + RequestsPerSecond: float64(totalRequests) / utils.MaxFloat64(uptimeSeconds, 1), + StatusCodeStats: models.SafeStatusCodeStats(stats["status_code_stats"]), + TopPaths: models.SafePathMetrics(stats["top_paths"]), + RecentRequests: models.SafeRequestLogs(stats["recent_requests"]), + TopReferers: models.SafePathMetrics(stats["top_referers"]), } w.Header().Set("Content-Type", "application/json") @@ -88,765 +91,3 @@ func (h *ProxyHandler) MetricsHandler(w http.ResponseWriter, r *http.Request) { log.Printf("Error encoding metrics: %v", err) } } - -// 辅助函数 -func max(a, b int64) int64 { - if a > b { - return a - } - return b -} - -// 修改模板,添加登录页面 -var loginTemplate = ` - - - - Proxy-Go Metrics Login - - - - -
-

Metrics Login

-
密码错误
-
- -
- -
- - - - -` - -// 修改原有的 metricsTemplate,添加 token 检查 -var metricsTemplate = ` - - - - Proxy-Go Metrics - - - - -

Proxy-Go Metrics

- -
-
-

基础指标

-
- 运行时间 - -
-
- 当前活跃请求 - -
-
- 总请求数 - -
-
- 错误数 - -
-
- 错误率 - -
-
- -
-

系统指标

-
- Goroutine数量 - -
-
- 内存使用 - -
-
- -
-

性能指标

-
- 平均响应时间 - -
-
- 每秒请求数 - -
-
- -
-

流量统计

-
- 总传输字节 - -
-
- 每传输 - -
-
-
- -
-

状态码统计

-
-
- -
-

热门路径 (Top 10)

- - - - - - - - - - - -
路径请求数错误数平均延迟传输大小
-
- -
-

最近请求

- - - - - - - - - - - - -
时间路径状态延迟大小客户端IP
-
- -
-

热门引用来源 (Top 10)

- - - - - - - - -
来源请求数
-
- - - - -
- - - - -` - -// 添加认证中间件 -func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - if auth == "" || !strings.HasPrefix(auth, "Bearer ") { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - - token := strings.TrimPrefix(auth, "Bearer ") - if !h.auth.validateToken(token) { - http.Error(w, "Invalid token", http.StatusUnauthorized) - return - } - - next(w, r) - } -} - -// 修改处理器 -func (h *ProxyHandler) MetricsPageHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write([]byte(loginTemplate)) -} - -func (h *ProxyHandler) MetricsDashboardHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write([]byte(metricsTemplate)) -} - -func (h *ProxyHandler) MetricsAuthHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - var req struct { - Password string `json:"password"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - if req.Password != h.config.Metrics.Password { - http.Error(w, "Invalid password", http.StatusUnauthorized) - return - } - - token := h.auth.generateToken() - h.auth.addToken(token, time.Duration(h.config.Metrics.TokenExpiry)*time.Second) - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]string{ - "token": token, - }) -} - -// 添加安全的类型转换辅助函数 -func safeStatusCodeStats(v interface{}) map[string]int64 { - if v == nil { - return make(map[string]int64) - } - if m, ok := v.(map[string]int64); ok { - return m - } - return make(map[string]int64) -} - -func safePathMetrics(v interface{}) []models.PathMetrics { - if v == nil { - return []models.PathMetrics{} - } - if m, ok := v.([]models.PathMetrics); ok { - return m - } - return []models.PathMetrics{} -} - -func safeRequestLogs(v interface{}) []models.RequestLog { - if v == nil { - return []models.RequestLog{} - } - if m, ok := v.([]models.RequestLog); ok { - return m - } - return []models.RequestLog{} -} - -func safeInt64(v interface{}) int64 { - if v == nil { - return 0 - } - if i, ok := v.(int64); ok { - return i - } - return 0 -} - -func safeInt(v interface{}) int { - if v == nil { - return 0 - } - if i, ok := v.(int); ok { - return i - } - return 0 -} - -func safeString(v interface{}) string { - if v == nil { - return "0 B" // 返回默认值 - } - if s, ok := v.(string); ok { - return s - } - return "0 B" // 返回默认值 -} diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index e822c82..12f17d4 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -1,9 +1,12 @@ package handler import ( + "bytes" + "context" "fmt" "io" "log" + "net" "net/http" "net/url" "proxy-go/internal/config" @@ -17,58 +20,310 @@ import ( ) const ( - defaultBufferSize = 32 * 1024 // 32KB + smallBufferSize = 4 * 1024 // 4KB + mediumBufferSize = 32 * 1024 // 32KB + largeBufferSize = 64 * 1024 // 64KB + + // 超时时间常量 + clientConnTimeout = 3 * time.Second // 客户端连接超时 + proxyRespTimeout = 10 * time.Second // 代理响应超时 + backendServTimeout = 8 * time.Second // 后端服务超时 + idleConnTimeout = 120 * time.Second // 空闲连接超时 + tlsHandshakeTimeout = 5 * time.Second // TLS握手超时 + + // 限流相关常量 + globalRateLimit = 1000 // 全局每秒请求数限制 + globalBurstLimit = 200 // 全局突发请求数限制 + perHostRateLimit = 100 // 每个host每秒请求数限制 + perHostBurstLimit = 50 // 每个host突发请求数限制 + perIPRateLimit = 20 // 每个IP每秒请求数限制 + perIPBurstLimit = 10 // 每个IP突发请求数限制 + cleanupInterval = 10 * time.Minute // 清理过期限流器的间隔 ) -var bufferPool = sync.Pool{ - New: func() interface{} { - buf := make([]byte, defaultBufferSize) - return &buf - }, +// 定义不同大小的缓冲池 +var ( + smallBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, smallBufferSize)) + }, + } + + mediumBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, mediumBufferSize)) + }, + } + + largeBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, largeBufferSize)) + }, + } + + // 用于大文件传输的字节切片池 + byteSlicePool = sync.Pool{ + New: func() interface{} { + b := make([]byte, largeBufferSize) + return &b + }, + } +) + +// getBuffer 根据大小选择合适的缓冲池 +func getBuffer(size int64) (*bytes.Buffer, func()) { + var buf *bytes.Buffer + var pool *sync.Pool + + switch { + case size <= smallBufferSize: + pool = &smallBufferPool + case size <= mediumBufferSize: + pool = &mediumBufferPool + default: + pool = &largeBufferPool + } + + buf = pool.Get().(*bytes.Buffer) + buf.Reset() // 重置缓冲区 + + return buf, func() { + if buf != nil { + pool.Put(buf) + } + } +} + +// 添加 hop-by-hop 头部映射 +var hopHeadersMap = make(map[string]bool) + +func init() { + headers := []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Proxy-Connection", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", + } + for _, h := range headers { + hopHeadersMap[h] = true + } +} + +// ErrorHandler 定义错误处理函数类型 +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) + +// RateLimiter 定义限流器接口 +type RateLimiter interface { + Allow() bool + Clean(now time.Time) +} + +// 限流管理器 +type rateLimitManager struct { + globalLimiter *rate.Limiter + hostLimiters *sync.Map // host -> *rate.Limiter + ipLimiters *sync.Map // IP -> *rate.Limiter + lastCleanup time.Time +} + +// 创建新的限流管理器 +func newRateLimitManager() *rateLimitManager { + manager := &rateLimitManager{ + globalLimiter: rate.NewLimiter(rate.Limit(globalRateLimit), globalBurstLimit), + hostLimiters: &sync.Map{}, + ipLimiters: &sync.Map{}, + lastCleanup: time.Now(), + } + + // 启动清理协程 + go manager.cleanupLoop() + return manager +} + +func (m *rateLimitManager) cleanupLoop() { + ticker := time.NewTicker(cleanupInterval) + for range ticker.C { + now := time.Now() + m.cleanup(now) + } +} + +func (m *rateLimitManager) cleanup(now time.Time) { + m.hostLimiters.Range(func(key, value interface{}) bool { + if now.Sub(m.lastCleanup) > cleanupInterval { + m.hostLimiters.Delete(key) + } + return true + }) + + m.ipLimiters.Range(func(key, value interface{}) bool { + if now.Sub(m.lastCleanup) > cleanupInterval { + m.ipLimiters.Delete(key) + } + return true + }) + + m.lastCleanup = now +} + +func (m *rateLimitManager) getHostLimiter(host string) *rate.Limiter { + if limiter, exists := m.hostLimiters.Load(host); exists { + return limiter.(*rate.Limiter) + } + + limiter := rate.NewLimiter(rate.Limit(perHostRateLimit), perHostBurstLimit) + m.hostLimiters.Store(host, limiter) + return limiter +} + +func (m *rateLimitManager) getIPLimiter(ip string) *rate.Limiter { + if limiter, exists := m.ipLimiters.Load(ip); exists { + return limiter.(*rate.Limiter) + } + + limiter := rate.NewLimiter(rate.Limit(perIPRateLimit), perIPBurstLimit) + m.ipLimiters.Store(ip, limiter) + return limiter +} + +// 检查是否允许请求 +func (m *rateLimitManager) allowRequest(r *http.Request) error { + // 全局限流检查 + if !m.globalLimiter.Allow() { + return fmt.Errorf("global rate limit exceeded") + } + + // Host限流检查 + host := r.Host + if host != "" { + if !m.getHostLimiter(host).Allow() { + return fmt.Errorf("host rate limit exceeded for %s", host) + } + } + + // IP限流检查 + ip := utils.GetClientIP(r) + if ip != "" { + if !m.getIPLimiter(ip).Allow() { + return fmt.Errorf("ip rate limit exceeded for %s", ip) + } + } + + return nil } type ProxyHandler struct { - pathMap map[string]config.PathConfig - client *http.Client - limiter *rate.Limiter - startTime time.Time - config *config.Config - auth *authManager + pathMap map[string]config.PathConfig + client *http.Client + limiter *rate.Limiter + startTime time.Time + config *config.Config + auth *authManager + errorHandler ErrorHandler // 添加错误处理器 + rateLimiter *rateLimitManager } // 修改参数类型 func NewProxyHandler(cfg *config.Config) *ProxyHandler { + dialer := &net.Dialer{ + Timeout: clientConnTimeout, // 客户端连接超时 + KeepAlive: 30 * time.Second, // TCP keepalive 间隔 + } + transport := &http.Transport{ - MaxIdleConns: 100, // 最大空闲连接数 - MaxIdleConnsPerHost: 10, // 每个 host 的最大空闲连接数 - IdleConnTimeout: 90 * time.Second, // 空闲连接超时时间 + DialContext: dialer.DialContext, + MaxIdleConns: 200, + MaxIdleConnsPerHost: 20, + IdleConnTimeout: idleConnTimeout, // 空闲连接超时 + TLSHandshakeTimeout: tlsHandshakeTimeout, // TLS握手超时 + ExpectContinueTimeout: 1 * time.Second, + MaxConnsPerHost: 50, + DisableKeepAlives: false, + DisableCompression: false, + ForceAttemptHTTP2: true, + WriteBufferSize: 64 * 1024, + ReadBufferSize: 64 * 1024, + ResponseHeaderTimeout: backendServTimeout, // 后端服务响应头超时 } return &ProxyHandler{ pathMap: cfg.MAP, client: &http.Client{ Transport: transport, - Timeout: 30 * time.Second, + Timeout: proxyRespTimeout, // 整体代理响应超时 + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + return nil + }, + }, + limiter: rate.NewLimiter(rate.Limit(5000), 10000), + startTime: time.Now(), + config: cfg, + auth: newAuthManager(), + rateLimiter: newRateLimitManager(), + errorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("[Error] %s %s -> %v", r.Method, r.URL.Path, err) + if strings.Contains(err.Error(), "rate limit exceeded") { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + } else { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } }, - limiter: rate.NewLimiter(rate.Limit(5000), 10000), - startTime: time.Now(), - config: cfg, - auth: newAuthManager(), } } +// SetErrorHandler 允许自定义错误处理函数 +func (h *ProxyHandler) SetErrorHandler(handler ErrorHandler) { + if handler != nil { + h.errorHandler = handler + } +} + +// copyResponse 使用零拷贝方式传输数据 +func copyResponse(dst io.Writer, src io.Reader, flusher http.Flusher) (int64, error) { + buf := byteSlicePool.Get().(*[]byte) + defer byteSlicePool.Put(buf) + + written, err := io.CopyBuffer(dst, src, *buf) + if err == nil && flusher != nil { + flusher.Flush() + } + return written, err +} + func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // 添加 panic 恢复 + defer func() { + if err := recover(); err != nil { + log.Printf("[Panic] %s %s -> %v", r.Method, r.URL.Path, err) + h.errorHandler(w, r, fmt.Errorf("panic: %v", err)) + } + }() + collector := metrics.GetCollector() collector.BeginRequest() defer collector.EndRequest() - if !h.limiter.Allow() { - http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + // 限流检查 + if err := h.rateLimiter.allowRequest(r); err != nil { + h.errorHandler(w, r, err) return } start := time.Now() + // 创建带超时的上下文 + ctx, cancel := context.WithTimeout(r.Context(), proxyRespTimeout) + defer cancel() + r = r.WithContext(ctx) + // 处理根路径请求 if r.URL.Path == "/" { w.WriteHeader(http.StatusOK) @@ -103,7 +358,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // URL 解码,然后重新编码,确保特殊字符被正确处理 decodedPath, err := url.QueryUnescape(targetPath) if err != nil { - http.Error(w, "Error decoding path", http.StatusInternalServerError) + h.errorHandler(w, r, fmt.Errorf("error decoding path: %v", err)) log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]", utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) return @@ -123,41 +378,27 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 解析目标 URL 以获取 host parsedURL, err := url.Parse(targetURL) if err != nil { - http.Error(w, "Error parsing target URL", http.StatusInternalServerError) + h.errorHandler(w, r, fmt.Errorf("error parsing URL: %v", err)) log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]", utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) return } - // 创建新的请求 - proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) + // 创建新的请求时使用带超时的上下文 + proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL, r.Body) if err != nil { - http.Error(w, "Error creating proxy request", http.StatusInternalServerError) + h.errorHandler(w, r, fmt.Errorf("error creating request: %v", err)) return } - // 复制原始请求头 + // 添加请求追踪标识 + requestID := utils.GenerateRequestID() + proxyReq.Header.Set("X-Request-ID", requestID) + w.Header().Set("X-Request-ID", requestID) + + // 复制并处理请求头 copyHeader(proxyReq.Header, r.Header) - // 特别处理图片请求 - // if utils.IsImageRequest(r.URL.Path) { - // // 设置优化的 Accept 头 - // accept := r.Header.Get("Accept") - // if accept != "" { - // proxyReq.Header.Set("Accept", accept) - // } else { - // proxyReq.Header.Set("Accept", "image/avif,image/webp,image/jpeg,image/png,*/*;q=0.8") - // } - - // // 设置 Cloudflare 特定的头部 - // proxyReq.Header.Set("CF-Accept-Content", "image/avif,image/webp") - // proxyReq.Header.Set("CF-Optimize-Images", "on") - - // // 删除可能影响缓存的头部 - // proxyReq.Header.Del("If-None-Match") - // proxyReq.Header.Del("If-Modified-Since") - // proxyReq.Header.Set("Cache-Control", "no-cache") - // } // 特别处理图片请求 if utils.IsImageRequest(r.URL.Path) { // 获取 Accept 头 @@ -190,13 +431,31 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + // 处理 Cookie 安全属性 + if r.TLS != nil && len(proxyReq.Cookies()) > 0 { + cookies := proxyReq.Cookies() + for _, cookie := range cookies { + if !cookie.Secure { + cookie.Secure = true + } + if !cookie.HttpOnly { + cookie.HttpOnly = true + } + } + } + // 发送代理请求 resp, err := h.client.Do(proxyReq) - if err != nil { - http.Error(w, "Error forwarding request", http.StatusBadGateway) - log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", - utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) + if ctx.Err() == context.DeadlineExceeded { + h.errorHandler(w, r, fmt.Errorf("request timeout after %v", proxyRespTimeout)) + log.Printf("[Timeout] %s %s -> timeout after %v", + r.Method, r.URL.Path, proxyRespTimeout) + } else { + h.errorHandler(w, r, fmt.Errorf("proxy error: %v", err)) + log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) + } return } defer resp.Body.Close() @@ -212,13 +471,19 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 根据响应大小选择不同的处理策略 contentLength := resp.ContentLength if contentLength > 0 && contentLength < 1<<20 { // 1MB 以下的小响应 - // 直接读取到内存并一次性写入 - body, err := io.ReadAll(resp.Body) + // 获取合适大小的缓冲区 + buf, putBuffer := getBuffer(contentLength) + defer putBuffer() + + // 使用缓冲区读取响应 + _, err := io.Copy(buf, resp.Body) if err != nil { - http.Error(w, "Error reading response", http.StatusInternalServerError) + h.errorHandler(w, r, fmt.Errorf("error reading response: %v", err)) return } - written, err := w.Write(body) + + // 一次性写入响应 + written, err := w.Write(buf.Bytes()) if err != nil { if !isConnectionClosed(err) { log.Printf("Error writing response: %v", err) @@ -226,38 +491,18 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(written), utils.GetClientIP(r), r) } else { - // 大响应使用流式传输 + // 大响应使用零拷贝传输 var bytesCopied int64 - if f, ok := w.(http.Flusher); ok { - bufPtr := bufferPool.Get().(*[]byte) - defer bufferPool.Put(bufPtr) - buf := *bufPtr + var err error - for { - n, rerr := resp.Body.Read(buf) - if n > 0 { - bytesCopied += int64(n) - _, werr := w.Write(buf[:n]) - if werr != nil { - log.Printf("Error writing response: %v", werr) - return - } - f.Flush() - } - if rerr == io.EOF { - break - } - if rerr != nil { - log.Printf("Error reading response: %v", rerr) - break - } - } + if f, ok := w.(http.Flusher); ok { + bytesCopied, err = copyResponse(w, resp.Body, f) } else { - // 如果不支持 Flusher,使用普通的 io.Copy - bytesCopied, err = io.Copy(w, resp.Body) - if err != nil { - log.Printf("Error copying response: %v", err) - } + bytesCopied, err = copyResponse(w, resp.Body, nil) + } + + if err != nil && !isConnectionClosed(err) { + log.Printf("Error copying response: %v", err) } // 记录访问日志 @@ -277,9 +522,19 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func copyHeader(dst, src http.Header) { + // 处理 Connection 头部指定的其他 hop-by-hop 头部 + if connection := src.Get("Connection"); connection != "" { + for _, h := range strings.Split(connection, ",") { + hopHeadersMap[strings.TrimSpace(h)] = true + } + } + + // 使用 map 快速查找,跳过 hop-by-hop 头部 for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) + if !hopHeadersMap[k] { + for _, v := range vv { + dst.Add(k, v) + } } } } diff --git a/internal/models/utils.go b/internal/models/utils.go new file mode 100644 index 0000000..7d0129b --- /dev/null +++ b/internal/models/utils.go @@ -0,0 +1,34 @@ +package models + +// SafeStatusCodeStats 安全地将 interface{} 转换为状态码统计 +func SafeStatusCodeStats(v interface{}) map[string]int64 { + if v == nil { + return make(map[string]int64) + } + if m, ok := v.(map[string]int64); ok { + return m + } + return make(map[string]int64) +} + +// SafePathMetrics 安全地将 interface{} 转换为路径指标 +func SafePathMetrics(v interface{}) []PathMetrics { + if v == nil { + return []PathMetrics{} + } + if m, ok := v.([]PathMetrics); ok { + return m + } + return []PathMetrics{} +} + +// SafeRequestLogs 安全地将 interface{} 转换为请求日志 +func SafeRequestLogs(v interface{}) []RequestLog { + if v == nil { + return []RequestLog{} + } + if m, ok := v.([]RequestLog); ok { + return m + } + return []RequestLog{} +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 2161caa..d0fecf8 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,6 +2,8 @@ package utils import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "log" "net" @@ -61,6 +63,16 @@ func init() { }() } +// GenerateRequestID 生成唯一的请求ID +func GenerateRequestID() string { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + // 如果随机数生成失败,使用时间戳作为备选 + return fmt.Sprintf("%x", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} + func GetClientIP(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip @@ -223,3 +235,52 @@ func isCacheHit(url string) bool { } return false } + +// SafeInt64 安全地将 interface{} 转换为 int64 +func SafeInt64(v interface{}) int64 { + if v == nil { + return 0 + } + if i, ok := v.(int64); ok { + return i + } + return 0 +} + +// SafeInt 安全地将 interface{} 转换为 int +func SafeInt(v interface{}) int { + if v == nil { + return 0 + } + if i, ok := v.(int); ok { + return i + } + return 0 +} + +// SafeString 安全地将 interface{} 转换为 string +func SafeString(v interface{}, defaultValue string) string { + if v == nil { + return defaultValue + } + if s, ok := v.(string); ok { + return s + } + return defaultValue +} + +// Max 返回两个 int64 中的较大值 +func Max(a, b int64) int64 { + if a > b { + return a + } + return b +} + +// MaxFloat64 返回两个 float64 中的较大值 +func MaxFloat64(a, b float64) float64 { + if a > b { + return a + } + return b +} diff --git a/main.go b/main.go index 2a045e3..7943cf4 100644 --- a/main.go +++ b/main.go @@ -75,22 +75,36 @@ func main() { // 创建主处理器 mainHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 先处理监控路由 - switch r.URL.Path { - case "/metrics": - proxyHandler.AuthMiddleware(proxyHandler.MetricsHandler)(w, r) - return - case "/metrics/ui": - proxyHandler.MetricsPageHandler(w, r) - return - case "/metrics/auth": - proxyHandler.MetricsAuthHandler(w, r) - return - case "/metrics/dashboard": - proxyHandler.MetricsDashboardHandler(w, r) + // 处理静态文件 + if strings.HasPrefix(r.URL.Path, "/web/static/") { + http.StripPrefix("/web/static/", http.FileServer(http.Dir("web/static"))).ServeHTTP(w, r) return } + // 处理管理路由 + if strings.HasPrefix(r.URL.Path, "/admin/") { + switch r.URL.Path { + case "/admin/login": + http.ServeFile(w, r, "web/templates/admin/login.html") + return + case "/admin/metrics": + proxyHandler.AuthMiddleware(proxyHandler.MetricsHandler)(w, r) + return + case "/admin/config": + proxyHandler.AuthMiddleware(handler.NewConfigHandler(cfg).ServeHTTP)(w, r) + return + case "/admin/config/get": + proxyHandler.AuthMiddleware(handler.NewConfigHandler(cfg).ServeHTTP)(w, r) + return + case "/admin/config/save": + proxyHandler.AuthMiddleware(handler.NewConfigHandler(cfg).ServeHTTP)(w, r) + return + case "/admin/auth": + proxyHandler.AuthHandler(w, r) + return + } + } + // 遍历所有处理器 for _, h := range handlers { if h.matcher(r) { @@ -108,7 +122,7 @@ func main() { // 创建服务器 server := &http.Server{ - Addr: ":80", + Addr: ":3336", Handler: handler, } @@ -124,7 +138,7 @@ func main() { }() // 启动服务器 - log.Println("Starting proxy server on :80") + log.Println("Starting proxy server on :3336") if err := server.ListenAndServe(); err != http.ErrServerClosed { log.Fatal("Error starting server:", err) } diff --git a/web/static/css/main.css b/web/static/css/main.css new file mode 100644 index 0000000..b7ba486 --- /dev/null +++ b/web/static/css/main.css @@ -0,0 +1,72 @@ +body { + font-family: Arial, sans-serif; + margin: 20px; + background-color: #f5f5f5; +} + +.container { + max-width: 1200px; + margin: 0 auto; + background-color: white; + padding: 20px; + border-radius: 8px; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +h1 { + color: #333; + margin-bottom: 20px; +} + +#editor { + width: 100%; + height: 600px; + margin-bottom: 20px; + border: 1px solid #ddd; + border-radius: 4px; +} + +.button-group { + margin-bottom: 20px; +} + +button { + background-color: #4CAF50; + color: white; + padding: 10px 20px; + border: none; + border-radius: 4px; + cursor: pointer; + margin-right: 10px; +} + +button:hover { + background-color: #45a049; +} + +button.secondary { + background-color: #008CBA; +} + +button.secondary:hover { + background-color: #007B9E; +} + +#message { + padding: 10px; + margin-top: 10px; + border-radius: 4px; + display: none; +} + +.success { + background-color: #dff0d8; + color: #3c763d; + border: 1px solid #d6e9c6; +} + +.error { + background-color: #f2dede; + color: #a94442; + border: 1px solid #ebccd1; +} \ No newline at end of file diff --git a/web/static/js/auth.js b/web/static/js/auth.js new file mode 100644 index 0000000..ec8bbd1 --- /dev/null +++ b/web/static/js/auth.js @@ -0,0 +1,61 @@ +// 检查认证状态 +function checkAuth() { + const token = localStorage.getItem('token'); + if (!token) { + window.location.href = '/admin/login'; + return false; + } + return true; +} + +// 登录函数 +async function login() { + const password = document.getElementById('password').value; + + try { + const response = await fetch('/admin/auth', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ password }) + }); + + if (!response.ok) { + throw new Error('登录失败'); + } + + const data = await response.json(); + localStorage.setItem('token', data.token); + window.location.href = '/admin/metrics'; + } catch (error) { + showToast(error.message, true); + } +} + +// 退出登录 +function logout() { + localStorage.removeItem('token'); + window.location.href = '/admin/login'; +} + +// 获取认证头 +function getAuthHeaders() { + const token = localStorage.getItem('token'); + return { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json' + }; +} + +// 显示提示消息 +function showToast(message, isError = false) { + const toast = document.createElement('div'); + toast.className = `toast toast-end ${isError ? 'alert alert-error' : 'alert alert-success'}`; + toast.innerHTML = `${message}`; + document.body.appendChild(toast); + + setTimeout(() => { + toast.remove(); + }, 3000); +} \ No newline at end of file diff --git a/web/static/js/config.js b/web/static/js/config.js new file mode 100644 index 0000000..db06200 --- /dev/null +++ b/web/static/js/config.js @@ -0,0 +1,66 @@ +let editor = ace.edit("editor"); +editor.setTheme("ace/theme/monokai"); +editor.session.setMode("ace/mode/json"); +editor.setOptions({ + fontSize: "14px" +}); + +function showMessage(msg, isError = false) { + const msgDiv = document.getElementById('message'); + msgDiv.textContent = msg; + msgDiv.className = isError ? 'error' : 'success'; + msgDiv.style.display = 'block'; + setTimeout(() => { + msgDiv.style.display = 'none'; + }, 5000); +} + +async function loadConfig() { + try { + const response = await fetch('/metrics/config/get'); + if (!response.ok) { + throw new Error('加载配置失败'); + } + const config = await response.json(); + editor.setValue(JSON.stringify(config, null, 2), -1); + showMessage('配置已加载'); + } catch (error) { + showMessage(error.message, true); + } +} + +async function saveConfig() { + try { + const config = JSON.parse(editor.getValue()); + const response = await fetch('/metrics/config/save', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(config) + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(error); + } + + const result = await response.json(); + showMessage(result.message); + } catch (error) { + showMessage(error.message, true); + } +} + +function formatJson() { + try { + const config = JSON.parse(editor.getValue()); + editor.setValue(JSON.stringify(config, null, 2), -1); + showMessage('JSON已格式化'); + } catch (error) { + showMessage('JSON格式错误: ' + error.message, true); + } +} + +// 初始加载配置 +loadConfig(); \ No newline at end of file diff --git a/web/static/js/login.js b/web/static/js/login.js new file mode 100644 index 0000000..6d2d5e8 --- /dev/null +++ b/web/static/js/login.js @@ -0,0 +1,40 @@ +async function login() { + const password = document.getElementById('password').value; + + try { + const response = await fetch('/metrics/auth', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ password }) + }); + + if (!response.ok) { + throw new Error('登录失败'); + } + + const data = await response.json(); + localStorage.setItem('token', data.token); + window.location.href = '/metrics/dashboard'; + } catch (error) { + showMessage(error.message, true); + } +} + +function showMessage(msg, isError = false) { + const msgDiv = document.getElementById('message'); + msgDiv.textContent = msg; + msgDiv.className = isError ? 'error' : 'success'; + msgDiv.style.display = 'block'; + setTimeout(() => { + msgDiv.style.display = 'none'; + }, 5000); +} + +// 添加回车键监听 +document.getElementById('password').addEventListener('keypress', function(e) { + if (e.key === 'Enter') { + login(); + } +}); \ No newline at end of file diff --git a/web/static/js/metrics.js b/web/static/js/metrics.js new file mode 100644 index 0000000..151838b --- /dev/null +++ b/web/static/js/metrics.js @@ -0,0 +1,105 @@ +async function loadMetrics() { + try { + const token = localStorage.getItem('token'); + if (!token) { + window.location.href = '/metrics/ui'; + return; + } + + const response = await fetch('/metrics', { + headers: { + 'Authorization': `Bearer ${token}` + } + }); + + if (!response.ok) { + if (response.status === 401) { + window.location.href = '/metrics/ui'; + return; + } + throw new Error('加载监控数据失败'); + } + + const metrics = await response.json(); + displayMetrics(metrics); + } catch (error) { + showMessage(error.message, true); + } +} + +function displayMetrics(metrics) { + const container = document.getElementById('metrics'); + container.innerHTML = ''; + + // 添加基本信息 + addSection(container, '基本信息', { + '运行时间': metrics.uptime, + '总请求数': metrics.totalRequests, + '活跃请求数': metrics.activeRequests, + '错误请求数': metrics.totalErrors, + '总传输字节': formatBytes(metrics.totalBytes) + }); + + // 添加状态码统计 + addSection(container, '状态码统计', metrics.statusStats); + + // 添加路径统计 + addSection(container, '路径统计', metrics.pathStats); + + // 添加来源统计 + addSection(container, '来源统计', metrics.refererStats); + + // 添加延迟统计 + addSection(container, '延迟统计', { + '平均延迟': `${metrics.avgLatency}ms`, + '延迟分布': metrics.latencyBuckets + }); +} + +function addSection(container, title, data) { + const section = document.createElement('div'); + section.className = 'metrics-section'; + + const titleElem = document.createElement('h2'); + titleElem.textContent = title; + section.appendChild(titleElem); + + const content = document.createElement('div'); + content.className = 'metrics-content'; + + for (const [key, value] of Object.entries(data)) { + const item = document.createElement('div'); + item.className = 'metrics-item'; + item.innerHTML = `${key}: ${value}`; + content.appendChild(item); + } + + section.appendChild(content); + container.appendChild(section); +} + +function formatBytes(bytes) { + if (bytes === 0) return '0 B'; + const k = 1024; + const sizes = ['B', 'KB', 'MB', 'GB', 'TB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; +} + +function showMessage(msg, isError = false) { + const msgDiv = document.getElementById('message'); + if (!msgDiv) return; + + msgDiv.textContent = msg; + msgDiv.className = isError ? 'error' : 'success'; + msgDiv.style.display = 'block'; + setTimeout(() => { + msgDiv.style.display = 'none'; + }, 5000); +} + +// 初始加载监控数据 +loadMetrics(); + +// 每30秒刷新一次数据 +setInterval(loadMetrics, 30000); \ No newline at end of file diff --git a/web/templates/admin/config.html b/web/templates/admin/config.html new file mode 100644 index 0000000..7e2045e --- /dev/null +++ b/web/templates/admin/config.html @@ -0,0 +1,73 @@ +{{define "Content"}} +
+
+

配置管理

+
+ + + +
+
+
+
+ + + +{{end}} \ No newline at end of file diff --git a/web/templates/admin/layout.html b/web/templates/admin/layout.html new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/web/templates/admin/layout.html @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/templates/admin/metrics.html b/web/templates/admin/metrics.html new file mode 100644 index 0000000..6be15c0 --- /dev/null +++ b/web/templates/admin/metrics.html @@ -0,0 +1,216 @@ +{{define "Content"}} +
+
+
+

基础指标

+
+
+
运行时间
+
+
+
+
当前活跃请求
+
+
+
+
总请求数
+
+
+
+
错误数
+
+
+
+
+
+ +
+
+

系统指标

+
+
+
Goroutine数量
+
+
+
+
内存使用
+
+
+
+
平均响应时间
+
+
+
+
每秒请求数
+
+
+
+
+
+
+ +
+
+

状态码统计

+
+
+
+ +
+
+

热门路径 (Top 10)

+
+ + + + + + + + + + + +
路径请求数错误数平均延迟传输大小
+
+
+
+ +
+
+

最近请求

+
+ + + + + + + + + + + + +
时间路径状态延迟大小客户端IP
+
+
+
+ + +{{end}} \ No newline at end of file