diff --git a/data/config.json b/data/config.json index 1bc047e..a337443 100644 --- a/data/config.json +++ b/data/config.json @@ -33,5 +33,9 @@ "TargetHost": "cdn.jsdelivr.net", "TargetURL": "https://cdn.jsdelivr.net" } - ] + ], + "Metrics": { + "Password": "admin123", + "TokenExpiry": 86400 + } } \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index c9f6620..a8de295 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,4 +8,17 @@ services: - ./data:/app/data environment: - TZ=Asia/Shanghai - restart: always \ No newline at end of file + restart: always + deploy: + resources: + limits: + cpus: '1' + memory: 512M + reservations: + cpus: '0.25' + memory: 128M + healthcheck: + test: ["CMD", "wget", "-q", "--spider", "http://localhost:80/"] + interval: 30s + timeout: 3s + retries: 3 \ No newline at end of file diff --git a/go.mod b/go.mod index 41429da..caca1c4 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module proxy-go go 1.23.1 -require github.com/andybalholm/brotli v1.1.1 +require ( + github.com/andybalholm/brotli v1.1.1 + golang.org/x/time v0.8.0 +) diff --git a/go.sum b/go.sum index f5064fa..69caaa3 100644 --- a/go.sum +++ b/go.sum @@ -2,3 +2,5 @@ github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7X github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/internal/config/config.go b/internal/config/config.go index a46b739..90d7b8f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,8 +3,29 @@ package config import ( "encoding/json" "os" + "sync/atomic" + "time" ) +type ConfigManager struct { + config atomic.Value + configPath string +} + +func NewConfigManager(path string) *ConfigManager { + cm := &ConfigManager{configPath: path} + cm.loadConfig() + go cm.watchConfig() + return cm +} + +func (cm *ConfigManager) watchConfig() { + ticker := time.NewTicker(30 * time.Second) + for range ticker.C { + cm.loadConfig() + } +} + func Load(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { @@ -18,3 +39,16 @@ func Load(path string) (*Config, error) { return &config, nil } + +func (cm *ConfigManager) loadConfig() error { + config, err := Load(cm.configPath) + if err != nil { + return err + } + cm.config.Store(config) + return nil +} + +func (cm *ConfigManager) GetConfig() *Config { + return cm.config.Load().(*Config) +} diff --git a/internal/config/types.go b/internal/config/types.go index b2a1b9a..f4f841a 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -9,6 +9,7 @@ type Config struct { MAP map[string]PathConfig `json:"MAP"` // 改为使用PathConfig Compression CompressionConfig `json:"Compression"` FixedPaths []FixedPathConfig `json:"FixedPaths"` + Metrics MetricsConfig `json:"Metrics"` } type PathConfig struct { @@ -33,6 +34,11 @@ type FixedPathConfig struct { TargetURL string `json:"TargetURL"` } +type MetricsConfig struct { + Password string `json:"Password"` + TokenExpiry int `json:"TokenExpiry"` // token有效期(秒) +} + // 添加一个辅助方法来处理字符串到 PathConfig 的转换 func (c *Config) UnmarshalJSON(data []byte) error { // 创建一个临时结构来解析原始JSON @@ -40,6 +46,7 @@ func (c *Config) UnmarshalJSON(data []byte) error { MAP map[string]json.RawMessage `json:"MAP"` Compression CompressionConfig `json:"Compression"` FixedPaths []FixedPathConfig `json:"FixedPaths"` + Metrics MetricsConfig `json:"Metrics"` } var temp TempConfig @@ -75,6 +82,7 @@ func (c *Config) UnmarshalJSON(data []byte) error { // 复制其他字段 c.Compression = temp.Compression c.FixedPaths = temp.FixedPaths + c.Metrics = temp.Metrics return nil } diff --git a/internal/handler/auth.go b/internal/handler/auth.go new file mode 100644 index 0000000..0c751c0 --- /dev/null +++ b/internal/handler/auth.go @@ -0,0 +1,62 @@ +package handler + +import ( + "crypto/rand" + "encoding/base64" + "sync" + "time" +) + +type tokenInfo struct { + createdAt time.Time + expiresIn time.Duration +} + +type authManager struct { + tokens sync.Map +} + +func newAuthManager() *authManager { + am := &authManager{} + // 启动token清理goroutine + go am.cleanExpiredTokens() + return am +} + +func (am *authManager) generateToken() string { + b := make([]byte, 32) + rand.Read(b) + return base64.URLEncoding.EncodeToString(b) +} + +func (am *authManager) addToken(token string, expiry time.Duration) { + am.tokens.Store(token, tokenInfo{ + createdAt: time.Now(), + expiresIn: expiry, + }) +} + +func (am *authManager) validateToken(token string) bool { + if info, ok := am.tokens.Load(token); ok { + tokenInfo := info.(tokenInfo) + if time.Since(tokenInfo.createdAt) < tokenInfo.expiresIn { + return true + } + am.tokens.Delete(token) + } + return false +} + +func (am *authManager) cleanExpiredTokens() { + ticker := time.NewTicker(time.Hour) + for range ticker.C { + am.tokens.Range(func(key, value interface{}) bool { + token := key.(string) + info := value.(tokenInfo) + if time.Since(info.createdAt) >= info.expiresIn { + am.tokens.Delete(token) + } + return true + }) + } +} diff --git a/internal/handler/metrics.go b/internal/handler/metrics.go new file mode 100644 index 0000000..e23c4da --- /dev/null +++ b/internal/handler/metrics.go @@ -0,0 +1,406 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + "runtime" + "strings" + "sync/atomic" + "time" +) + +type Metrics struct { + // 基础指标 + Uptime string `json:"uptime"` + ActiveRequests int64 `json:"active_requests"` + TotalRequests int64 `json:"total_requests"` + TotalErrors int64 `json:"total_errors"` + ErrorRate float64 `json:"error_rate"` + + // 系统指标 + NumGoroutine int `json:"num_goroutine"` + MemoryUsage string `json:"memory_usage"` + + // 性能指标 + AverageResponseTime string `json:"avg_response_time"` + RequestsPerSecond float64 `json:"requests_per_second"` + + // 新增字段 + TotalBytes int64 `json:"total_bytes"` + BytesPerSecond float64 `json:"bytes_per_second"` + StatusCodeStats map[string]int64 `json:"status_code_stats"` + LatencyPercentiles map[string]float64 `json:"latency_percentiles"` + TopPaths []PathMetrics `json:"top_paths"` + RecentRequests []RequestLog `json:"recent_requests"` +} + +type PathMetrics struct { + Path string `json:"path"` + RequestCount int64 `json:"request_count"` + ErrorCount int64 `json:"error_count"` + AvgLatency string `json:"avg_latency"` + BytesTransferred int64 `json:"bytes_transferred"` +} + +// 添加格式化字节的辅助函数 +func formatBytes(bytes uint64) string { + const ( + MB = 1024 * 1024 + KB = 1024 + ) + + switch { + case bytes >= MB: + return fmt.Sprintf("%.2f MB", float64(bytes)/MB) + case bytes >= KB: + return fmt.Sprintf("%.2f KB", float64(bytes)/KB) + default: + return fmt.Sprintf("%d Bytes", bytes) + } +} + +func (h *ProxyHandler) MetricsHandler(w http.ResponseWriter, r *http.Request) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + + // 获取状态码统计 + statusStats := make(map[string]int64) + for i, v := range h.metrics.statusStats { + statusStats[fmt.Sprintf("%dxx", i+1)] = v.Load() + } + + // 获取Top 10路径统计 + var pathMetrics []PathMetrics + h.metrics.pathStats.Range(func(key, value interface{}) bool { + stats := value.(*PathStats) + pathMetrics = append(pathMetrics, PathMetrics{ + Path: key.(string), + RequestCount: stats.requests.Load(), + ErrorCount: stats.errors.Load(), + AvgLatency: formatDuration(time.Duration(stats.latencySum.Load() / stats.requests.Load())), + BytesTransferred: stats.bytes.Load(), + }) + return len(pathMetrics) < 10 + }) + + // 获取最近的请求 + var recentReqs []RequestLog + h.recentRequests.RLock() + cursor := h.recentRequests.cursor.Load() + for i := 0; i < 10; i++ { + idx := (cursor - int64(i) + 1000) % 1000 + if h.recentRequests.items[idx] != nil { + recentReqs = append(recentReqs, *h.recentRequests.items[idx]) + } + } + h.recentRequests.RUnlock() + + metrics := Metrics{ + Uptime: time.Since(h.startTime).String(), + ActiveRequests: atomic.LoadInt64(&h.metrics.activeRequests), + TotalRequests: atomic.LoadInt64(&h.metrics.totalRequests), + TotalErrors: atomic.LoadInt64(&h.metrics.totalErrors), + ErrorRate: float64(h.metrics.totalErrors) / float64(h.metrics.totalRequests), + NumGoroutine: runtime.NumGoroutine(), + MemoryUsage: formatBytes(m.Alloc), + TotalBytes: h.metrics.totalBytes.Load(), + BytesPerSecond: float64(h.metrics.totalBytes.Load()) / time.Since(h.startTime).Seconds(), + StatusCodeStats: statusStats, + TopPaths: pathMetrics, + RecentRequests: recentReqs, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(metrics) +} + +// 添加格式化时间的辅助函数 +func formatDuration(d time.Duration) string { + if d < time.Millisecond { + return fmt.Sprintf("%.2f μs", float64(d.Microseconds())) + } + if d < time.Second { + return fmt.Sprintf("%.2f ms", float64(d.Milliseconds())) + } + return fmt.Sprintf("%.2f s", d.Seconds()) +} + +// 修改模板,添加登录页面 +var loginTemplate = ` + + + + Proxy-Go Metrics Login + + + + +
+

Metrics Login

+
密码错误
+
+ +
+ +
+ + + + +` + +// 修改原有的 metricsTemplate,添加 token 检查 +var metricsTemplate = ` + + + + Proxy-Go Metrics + + + + + + + +` + +// 添加认证中间件 +func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" || !strings.HasPrefix(auth, "Bearer ") { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + token := strings.TrimPrefix(auth, "Bearer ") + if !h.auth.validateToken(token) { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + next(w, r) + } +} + +// 修改处理器 +func (h *ProxyHandler) MetricsPageHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write([]byte(loginTemplate)) +} + +func (h *ProxyHandler) MetricsDashboardHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write([]byte(metricsTemplate)) +} + +func (h *ProxyHandler) MetricsAuthHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + if req.Password != h.config.Metrics.Password { + http.Error(w, "Invalid password", http.StatusUnauthorized) + return + } + + token := h.auth.generateToken() + h.auth.addToken(token, time.Duration(h.config.Metrics.TokenExpiry)*time.Second) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token": token, + }) +} diff --git a/internal/handler/mirror_proxy.go b/internal/handler/mirror_proxy.go index 16aa8aa..b790831 100644 --- a/internal/handler/mirror_proxy.go +++ b/internal/handler/mirror_proxy.go @@ -11,10 +11,23 @@ import ( "time" ) -type MirrorProxyHandler struct{} +type MirrorProxyHandler struct { + client *http.Client +} func NewMirrorProxyHandler() *MirrorProxyHandler { - return &MirrorProxyHandler{} + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + + return &MirrorProxyHandler{ + client: &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + }, + } } func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -91,14 +104,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { proxyReq.Host = parsedURL.Host // 发送请求 - client := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - // 可以添加其他传输设置,如TLS配置等 - }, - Timeout: 30 * time.Second, - } - resp, err := client.Do(proxyReq) + resp, err := h.client.Do(proxyReq) if err != nil { http.Error(w, "Error forwarding request", http.StatusBadGateway) log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error forwarding request: %v", diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index b0e6e79..b1c3557 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -10,33 +10,103 @@ import ( "proxy-go/internal/config" "proxy-go/internal/utils" "strings" + "sync" + "sync/atomic" "time" + + "golang.org/x/time/rate" ) const ( defaultBufferSize = 32 * 1024 // 32KB ) +var bufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, defaultBufferSize) + }, +} + type ProxyHandler struct { - pathMap map[string]config.PathConfig + pathMap map[string]config.PathConfig + client *http.Client + limiter *rate.Limiter + startTime time.Time + config *config.Config + auth *authManager + metrics struct { + activeRequests int64 + totalRequests int64 + totalErrors int64 + totalBytes atomic.Int64 // 总传输字节数 + pathStats sync.Map // 路径统计 map[string]*PathStats + statusStats [6]atomic.Int64 // HTTP状态码统计(1xx-5xx) + latencyBuckets [10]atomic.Int64 // 延迟分布(0-100ms, 100-200ms...) + } + recentRequests struct { + sync.RWMutex + items [1000]*RequestLog // 固定大小的环形缓冲区 + cursor atomic.Int64 // 当前位置 + } +} + +// 单个请求的统计信息 +type RequestLog struct { + Time time.Time + Path string + Status int + Latency time.Duration + BytesSent int64 + ClientIP string +} + +// 路径统计 +type PathStats struct { + requests atomic.Int64 + errors atomic.Int64 + bytes atomic.Int64 + latencySum atomic.Int64 } // 修改参数类型 -func NewProxyHandler(pathMap map[string]config.PathConfig) *ProxyHandler { +func NewProxyHandler(cfg *config.Config) *ProxyHandler { + transport := &http.Transport{ + MaxIdleConns: 100, // 最大空闲连接数 + MaxIdleConnsPerHost: 10, // 每个 host 的最大空闲连接数 + IdleConnTimeout: 90 * time.Second, // 空闲连接超时时间 + } + return &ProxyHandler{ - pathMap: pathMap, + pathMap: cfg.MAP, + client: &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + }, + limiter: rate.NewLimiter(rate.Limit(5000), 10000), + startTime: time.Now(), + config: cfg, + auth: newAuthManager(), } } func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - startTime := time.Now() + atomic.AddInt64(&h.metrics.activeRequests, 1) + atomic.AddInt64(&h.metrics.totalRequests, 1) + defer atomic.AddInt64(&h.metrics.activeRequests, -1) + + if !h.limiter.Allow() { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + + start := time.Now() // 处理根路径请求 if r.URL.Path == "/" { w.WriteHeader(http.StatusOK) fmt.Fprint(w, "Welcome to CZL proxy.") log.Printf("[%s] %s %s -> %d (root path) [%v]", - utils.GetClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(start)) return } @@ -55,7 +125,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if matchedPrefix == "" { http.NotFound(w, r) log.Printf("[%s] %s %s -> 404 (not found) [%v]", - utils.GetClientIP(r), r.Method, r.URL.Path, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, time.Since(start)) return } @@ -67,7 +137,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "Error decoding path", http.StatusInternalServerError) log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]", - utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) return } @@ -96,7 +166,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "Error parsing target URL", http.StatusInternalServerError) log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]", - utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) return } @@ -162,13 +232,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // 发送代理请求 - client := &http.Client{} - resp, err := client.Do(proxyReq) + resp, err := h.client.Do(proxyReq) if err != nil { http.Error(w, "Error forwarding request", http.StatusBadGateway) log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", - utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(start)) return } defer resp.Body.Close() @@ -181,49 +250,114 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 设置响应状态码 w.WriteHeader(resp.StatusCode) - // 使用流式传输复制响应体 - var bytesCopied int64 - if f, ok := w.(http.Flusher); ok { - buf := make([]byte, defaultBufferSize) - for { - n, rerr := resp.Body.Read(buf) - if n > 0 { - bytesCopied += int64(n) - _, werr := w.Write(buf[:n]) - if werr != nil { - log.Printf("Error writing response: %v", werr) - return - } - f.Flush() - } - if rerr == io.EOF { - break - } - if rerr != nil { - log.Printf("Error reading response: %v", rerr) - break - } - } - } else { - // 如果不支持 Flusher,使用普通的 io.Copy - bytesCopied, err = io.Copy(w, resp.Body) + // 根据响应大小选择不同的处理策略 + contentLength := resp.ContentLength + if contentLength > 0 && contentLength < 1<<20 { // 1MB 以下的小响应 + // 直接读取到内存并一次性写入 + body, err := io.ReadAll(resp.Body) if err != nil { - log.Printf("Error copying response: %v", err) + http.Error(w, "Error reading response", http.StatusInternalServerError) + return } + written, _ := w.Write(body) + h.recordStats(r.URL.Path, resp.StatusCode, time.Since(start), int64(written), r) + } else { + // 大响应使用流式传输 + var bytesCopied int64 + if f, ok := w.(http.Flusher); ok { + buf := bufferPool.Get().([]byte) + defer bufferPool.Put(buf) + for { + n, rerr := resp.Body.Read(buf) + if n > 0 { + bytesCopied += int64(n) + _, werr := w.Write(buf[:n]) + if werr != nil { + log.Printf("Error writing response: %v", werr) + return + } + f.Flush() + } + if rerr == io.EOF { + break + } + if rerr != nil { + log.Printf("Error reading response: %v", rerr) + break + } + } + } else { + // 如果不支持 Flusher,使用普通的 io.Copy + bytesCopied, err = io.Copy(w, resp.Body) + if err != nil { + log.Printf("Error copying response: %v", err) + } + } + + // 记录访问日志 + log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %-50s -> %s", + r.Method, // HTTP方法,左对齐占6位 + resp.StatusCode, // 状态码,占3位 + time.Since(start), // 处理时间,占12位 + utils.GetClientIP(r), // IP地址,占15位 + utils.FormatBytes(bytesCopied), // 传输大小,占10位 + utils.GetRequestSource(r), // 请求来源 + r.URL.Path, // 请求路径,左对齐占50位 + targetURL, // 目标URL + ) + + h.recordStats(r.URL.Path, resp.StatusCode, time.Since(start), bytesCopied, r) } - // 记录访问日志 - log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %-50s -> %s", - r.Method, // HTTP方法,左对齐占6位 - resp.StatusCode, // 状态码,占3位 - time.Since(startTime), // 处理时间,占12位 - utils.GetClientIP(r), // IP地址,占15位 - utils.FormatBytes(bytesCopied), // 传输大小,占10位 - utils.GetRequestSource(r), // 请求来源 - r.URL.Path, // 请求路径,左对齐占50位 - targetURL, // 目标URL - ) + if err != nil { + atomic.AddInt64(&h.metrics.totalErrors, 1) + } +} +func (h *ProxyHandler) recordStats(path string, status int, latency time.Duration, bytes int64, r *http.Request) { + // 更新总字节数 + h.metrics.totalBytes.Add(bytes) + + // 更新状态码统计 + if status >= 100 && status < 600 { + h.metrics.statusStats[status/100-1].Add(1) + } + + // 更新延迟分布 + bucket := int(latency.Milliseconds() / 100) + if bucket < 10 { + h.metrics.latencyBuckets[bucket].Add(1) + } + + // 更新路径统计 + if stats, ok := h.metrics.pathStats.Load(path); ok { + pathStats := stats.(*PathStats) + pathStats.requests.Add(1) + pathStats.bytes.Add(bytes) + pathStats.latencySum.Add(int64(latency)) + } else { + // 首次遇到该路径 + newStats := &PathStats{} + newStats.requests.Add(1) + newStats.bytes.Add(bytes) + newStats.latencySum.Add(int64(latency)) + h.metrics.pathStats.Store(path, newStats) + } + + // 记录最近的请求 + log := &RequestLog{ + Time: time.Now(), + Path: path, + Status: status, + Latency: latency, + BytesSent: bytes, + ClientIP: utils.GetClientIP(r), + } + + cursor := h.recentRequests.cursor.Add(1) % 1000 + h.recentRequests.Lock() + h.recentRequests.items[cursor] = log + h.recentRequests.Unlock() } func copyHeader(dst, src http.Header) { diff --git a/internal/middleware/compression.go b/internal/middleware/compression.go index ff0f6fb..bb8795e 100644 --- a/internal/middleware/compression.go +++ b/internal/middleware/compression.go @@ -8,6 +8,7 @@ import ( "net/http" "proxy-go/internal/compression" "strings" + "sync" ) const ( @@ -24,6 +25,12 @@ type CompressResponseWriter struct { compressed bool } +var writerPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriterSize(nil, defaultBufferSize) + }, +} + func CompressionMiddleware(manager compression.Manager) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -106,15 +113,21 @@ func (cw *CompressResponseWriter) Write(b []byte) (int, error) { // 延迟初始化压缩写入器 if cw.writer == nil { - var err error - cw.writer, err = cw.compressor.Compress(cw.ResponseWriter) + writer, err := cw.compressor.Compress(cw.ResponseWriter) if err != nil { return 0, err } - cw.bufferedWriter = bufio.NewWriterSize(cw.writer, defaultBufferSize) + cw.writer = writer + bw := writerPool.Get().(*bufio.Writer) + bw.Reset(writer) + cw.bufferedWriter = bw } - return cw.bufferedWriter.Write(b) + n, err := cw.bufferedWriter.Write(b) + if err != nil { + writerPool.Put(cw.bufferedWriter) + } + return n, err } // 实现 http.Hijacker 接口 diff --git a/main.go b/main.go index 03f6307..4172d39 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,7 @@ func main() { // 创建代理处理器 mirrorHandler := handler.NewMirrorProxyHandler() - proxyHandler := handler.NewProxyHandler(cfg.MAP) + proxyHandler := handler.NewProxyHandler(cfg) // 创建处理器链 handlers := []struct { @@ -80,6 +80,12 @@ func main() { handler = middleware.CompressionMiddleware(compManager)(handler) } + // 添加监控路由 + http.HandleFunc("/metrics", proxyHandler.AuthMiddleware(proxyHandler.MetricsHandler)) + http.HandleFunc("/metrics/ui", proxyHandler.MetricsPageHandler) + http.HandleFunc("/metrics/auth", proxyHandler.MetricsAuthHandler) + http.HandleFunc("/metrics/dashboard", proxyHandler.MetricsDashboardHandler) + // 创建服务器 server := &http.Server{ Addr: ":80",