diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index ea53a0b..8f06eb4 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -4,9 +4,9 @@ import ( "fmt" "io" "log" - "net" "net/http" "net/url" + "proxy-go/internal/utils" "strings" "time" ) @@ -33,7 +33,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprint(w, "Welcome to CZL proxy.") log.Printf("[%s] %s %s -> %d (root path) [%v]", - getClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime)) return } @@ -52,7 +52,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if matchedPrefix == "" { http.NotFound(w, r) log.Printf("[%s] %s %s -> 404 (not found) [%v]", - getClientIP(r), r.Method, r.URL.Path, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, time.Since(startTime)) return } @@ -63,7 +63,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "Error decoding path", http.StatusInternalServerError) log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]", - getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) return } @@ -80,7 +80,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "Error parsing target URL", http.StatusInternalServerError) log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]", - getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) return } @@ -89,7 +89,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "Error creating proxy request", http.StatusInternalServerError) log.Printf("[%s] %s %s -> 500 (error: %v) [%v]", - getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) return } @@ -99,12 +99,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 设置必要的头部,使用目标站点的 Host proxyReq.Host = parsedURL.Host proxyReq.Header.Set("Host", parsedURL.Host) - proxyReq.Header.Set("X-Real-IP", getClientIP(r)) + proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r)) proxyReq.Header.Set("X-Forwarded-Host", r.Host) proxyReq.Header.Set("X-Forwarded-Proto", r.URL.Scheme) // 添加或更新 X-Forwarded-For - if clientIP := getClientIP(r); clientIP != "" { + if clientIP := utils.GetClientIP(r); clientIP != "" { if prior := proxyReq.Header.Get("X-Forwarded-For"); prior != "" { proxyReq.Header.Set("X-Forwarded-For", prior+", "+clientIP) } else { @@ -123,7 +123,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "Error forwarding request", http.StatusBadGateway) log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", - getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) return } defer resp.Body.Close() @@ -165,9 +165,9 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // 记录访问日志 - log.Printf("[%s] %s %s -> %s -> %d (%s) [%v]", - getClientIP(r), r.Method, r.URL.Path, targetURL, - resp.StatusCode, formatBytes(bytesCopied), time.Since(startTime)) + log.Printf("[%s] %s %s%s -> %s -> %d (%s) [%v]", + utils.GetClientIP(r), r.Method, r.URL.Path, utils.GetRequestSource(r), targetURL, + resp.StatusCode, utils.FormatBytes(bytesCopied), time.Since(startTime)) } func copyHeader(dst, src http.Header) { @@ -177,32 +177,3 @@ func copyHeader(dst, src http.Header) { } } } - -func getClientIP(r *http.Request) string { - if ip := r.Header.Get("X-Real-IP"); ip != "" { - return ip - } - if ip := r.Header.Get("X-Forwarded-For"); ip != "" { - return strings.Split(ip, ",")[0] - } - if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { - return ip - } - return r.RemoteAddr -} - -func formatBytes(bytes int64) string { - const ( - MB = 1024 * 1024 - KB = 1024 - ) - - switch { - case bytes >= MB: - return fmt.Sprintf("%.2f MB", float64(bytes)/MB) - case bytes >= KB: - return fmt.Sprintf("%.2f KB", float64(bytes)/KB) - default: - return fmt.Sprintf("%d Bytes", bytes) - } -} diff --git a/internal/middleware/fixed_path_proxy.go b/internal/middleware/fixed_path_proxy.go index 0d414ae..4fbf3c5 100644 --- a/internal/middleware/fixed_path_proxy.go +++ b/internal/middleware/fixed_path_proxy.go @@ -2,12 +2,11 @@ package middleware import ( "errors" - "fmt" "io" "log" - "net" "net/http" "proxy-go/internal/config" + "proxy-go/internal/utils" "strings" "syscall" "time" @@ -34,7 +33,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle if err != nil { http.Error(w, "Error creating proxy request", http.StatusInternalServerError) log.Printf("[%s] %s %s -> 500 (error creating request: %v) [%v]", - getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) return } @@ -48,7 +47,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle // 设置必要的头部 proxyReq.Host = cfg.TargetHost proxyReq.Header.Set("Host", cfg.TargetHost) - proxyReq.Header.Set("X-Real-IP", getClientIP(r)) + proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r)) proxyReq.Header.Set("X-Scheme", r.URL.Scheme) // 发送代理请求 @@ -57,7 +56,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle if err != nil { http.Error(w, "Error forwarding request", http.StatusBadGateway) log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", - getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) + utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) return } defer resp.Body.Close() @@ -75,13 +74,13 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle // 复制响应体 bytesCopied, err := io.Copy(w, resp.Body) if err := handleCopyError(err); err != nil { - log.Printf("[%s] Error copying response: %v", getClientIP(r), err) + log.Printf("[%s] Error copying response: %v", utils.GetClientIP(r), err) } // 记录成功的请求 - log.Printf("[%s] %s %s -> %s -> %d (%s) [%v]", - getClientIP(r), r.Method, r.URL.Path, targetURL, - resp.StatusCode, formatBytes(bytesCopied), time.Since(startTime)) + log.Printf("[%s] %s %s%s -> %s -> %d (%s) [%v]", + utils.GetClientIP(r), r.Method, r.URL.Path, utils.GetRequestSource(r), targetURL, + resp.StatusCode, utils.FormatBytes(bytesCopied), time.Since(startTime)) return } @@ -93,20 +92,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle } } -// getClientIP 获取客户端IP -func getClientIP(r *http.Request) string { - if ip := r.Header.Get("X-Real-IP"); ip != "" { - return ip - } - if ip := r.Header.Get("X-Forwarded-For"); ip != "" { - return strings.Split(ip, ",")[0] - } - if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { - return ip - } - return r.RemoteAddr -} - func handleCopyError(err error) error { if err == nil { return nil @@ -122,20 +107,3 @@ func handleCopyError(err error) error { return err } - -// formatBytes 将字节数转换为可读的格式(MB/KB/Bytes) -func formatBytes(bytes int64) string { - const ( - MB = 1024 * 1024 - KB = 1024 - ) - - switch { - case bytes >= MB: - return fmt.Sprintf("%.2f MB", float64(bytes)/MB) - case bytes >= KB: - return fmt.Sprintf("%.2f KB", float64(bytes)/KB) - default: - return fmt.Sprintf("%d Bytes", bytes) - } -} diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..8984bc8 --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,46 @@ +package utils + +import ( + "fmt" + "net" + "net/http" + "strings" +) + +func GetClientIP(r *http.Request) string { + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + return strings.Split(ip, ",")[0] + } + if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return ip + } + return r.RemoteAddr +} + +// 获取请求来源 +func GetRequestSource(r *http.Request) string { + referer := r.Header.Get("Referer") + if referer != "" { + return fmt.Sprintf(" (from: %s)", referer) + } + return "" +} + +func FormatBytes(bytes int64) string { + const ( + MB = 1024 * 1024 + KB = 1024 + ) + + switch { + case bytes >= MB: + return fmt.Sprintf("%.2f MB", float64(bytes)/MB) + case bytes >= KB: + return fmt.Sprintf("%.2f KB", float64(bytes)/KB) + default: + return fmt.Sprintf("%d Bytes", bytes) + } +}