refactor(internal/handler, internal/middleware): Move getClientIP and formatBytes to utils package

This commit is contained in:
wood chen 2024-10-31 08:38:07 +08:00
parent c2c6e14736
commit 3c380ef5e9
3 changed files with 66 additions and 81 deletions

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"net/url" "net/url"
"proxy-go/internal/utils"
"strings" "strings"
"time" "time"
) )
@ -33,7 +33,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "Welcome to CZL proxy.") fmt.Fprint(w, "Welcome to CZL proxy.")
log.Printf("[%s] %s %s -> %d (root path) [%v]", log.Printf("[%s] %s %s -> %d (root path) [%v]",
getClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime))
return return
} }
@ -52,7 +52,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if matchedPrefix == "" { if matchedPrefix == "" {
http.NotFound(w, r) http.NotFound(w, r)
log.Printf("[%s] %s %s -> 404 (not found) [%v]", log.Printf("[%s] %s %s -> 404 (not found) [%v]",
getClientIP(r), r.Method, r.URL.Path, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, time.Since(startTime))
return return
} }
@ -63,7 +63,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
http.Error(w, "Error decoding path", http.StatusInternalServerError) http.Error(w, "Error decoding path", http.StatusInternalServerError)
log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]", log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]",
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
return return
} }
@ -80,7 +80,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
http.Error(w, "Error parsing target URL", http.StatusInternalServerError) http.Error(w, "Error parsing target URL", http.StatusInternalServerError)
log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]", log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]",
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
return return
} }
@ -89,7 +89,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
http.Error(w, "Error creating proxy request", http.StatusInternalServerError) http.Error(w, "Error creating proxy request", http.StatusInternalServerError)
log.Printf("[%s] %s %s -> 500 (error: %v) [%v]", log.Printf("[%s] %s %s -> 500 (error: %v) [%v]",
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
return return
} }
@ -99,12 +99,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 设置必要的头部,使用目标站点的 Host // 设置必要的头部,使用目标站点的 Host
proxyReq.Host = parsedURL.Host proxyReq.Host = parsedURL.Host
proxyReq.Header.Set("Host", parsedURL.Host) proxyReq.Header.Set("Host", parsedURL.Host)
proxyReq.Header.Set("X-Real-IP", getClientIP(r)) proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r))
proxyReq.Header.Set("X-Forwarded-Host", r.Host) proxyReq.Header.Set("X-Forwarded-Host", r.Host)
proxyReq.Header.Set("X-Forwarded-Proto", r.URL.Scheme) proxyReq.Header.Set("X-Forwarded-Proto", r.URL.Scheme)
// 添加或更新 X-Forwarded-For // 添加或更新 X-Forwarded-For
if clientIP := getClientIP(r); clientIP != "" { if clientIP := utils.GetClientIP(r); clientIP != "" {
if prior := proxyReq.Header.Get("X-Forwarded-For"); prior != "" { if prior := proxyReq.Header.Get("X-Forwarded-For"); prior != "" {
proxyReq.Header.Set("X-Forwarded-For", prior+", "+clientIP) proxyReq.Header.Set("X-Forwarded-For", prior+", "+clientIP)
} else { } else {
@ -123,7 +123,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
http.Error(w, "Error forwarding request", http.StatusBadGateway) http.Error(w, "Error forwarding request", http.StatusBadGateway)
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -165,9 +165,9 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// 记录访问日志 // 记录访问日志
log.Printf("[%s] %s %s -> %s -> %d (%s) [%v]", log.Printf("[%s] %s %s%s -> %s -> %d (%s) [%v]",
getClientIP(r), r.Method, r.URL.Path, targetURL, utils.GetClientIP(r), r.Method, r.URL.Path, utils.GetRequestSource(r), targetURL,
resp.StatusCode, formatBytes(bytesCopied), time.Since(startTime)) resp.StatusCode, utils.FormatBytes(bytesCopied), time.Since(startTime))
} }
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {
@ -177,32 +177,3 @@ func copyHeader(dst, src http.Header) {
} }
} }
} }
func getClientIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return ip
}
return r.RemoteAddr
}
func formatBytes(bytes int64) string {
const (
MB = 1024 * 1024
KB = 1024
)
switch {
case bytes >= MB:
return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
case bytes >= KB:
return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
default:
return fmt.Sprintf("%d Bytes", bytes)
}
}

View File

@ -2,12 +2,11 @@ package middleware
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"proxy-go/internal/config" "proxy-go/internal/config"
"proxy-go/internal/utils"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@ -34,7 +33,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
if err != nil { if err != nil {
http.Error(w, "Error creating proxy request", http.StatusInternalServerError) http.Error(w, "Error creating proxy request", http.StatusInternalServerError)
log.Printf("[%s] %s %s -> 500 (error creating request: %v) [%v]", log.Printf("[%s] %s %s -> 500 (error creating request: %v) [%v]",
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
return return
} }
@ -48,7 +47,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
// 设置必要的头部 // 设置必要的头部
proxyReq.Host = cfg.TargetHost proxyReq.Host = cfg.TargetHost
proxyReq.Header.Set("Host", cfg.TargetHost) proxyReq.Header.Set("Host", cfg.TargetHost)
proxyReq.Header.Set("X-Real-IP", getClientIP(r)) proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r))
proxyReq.Header.Set("X-Scheme", r.URL.Scheme) proxyReq.Header.Set("X-Scheme", r.URL.Scheme)
// 发送代理请求 // 发送代理请求
@ -57,7 +56,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
if err != nil { if err != nil {
http.Error(w, "Error forwarding request", http.StatusBadGateway) http.Error(w, "Error forwarding request", http.StatusBadGateway)
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]", log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime)) utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -75,13 +74,13 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
// 复制响应体 // 复制响应体
bytesCopied, err := io.Copy(w, resp.Body) bytesCopied, err := io.Copy(w, resp.Body)
if err := handleCopyError(err); err != nil { if err := handleCopyError(err); err != nil {
log.Printf("[%s] Error copying response: %v", getClientIP(r), err) log.Printf("[%s] Error copying response: %v", utils.GetClientIP(r), err)
} }
// 记录成功的请求 // 记录成功的请求
log.Printf("[%s] %s %s -> %s -> %d (%s) [%v]", log.Printf("[%s] %s %s%s -> %s -> %d (%s) [%v]",
getClientIP(r), r.Method, r.URL.Path, targetURL, utils.GetClientIP(r), r.Method, r.URL.Path, utils.GetRequestSource(r), targetURL,
resp.StatusCode, formatBytes(bytesCopied), time.Since(startTime)) resp.StatusCode, utils.FormatBytes(bytesCopied), time.Since(startTime))
return return
} }
@ -93,20 +92,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
} }
} }
// getClientIP 获取客户端IP
func getClientIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return ip
}
return r.RemoteAddr
}
func handleCopyError(err error) error { func handleCopyError(err error) error {
if err == nil { if err == nil {
return nil return nil
@ -122,20 +107,3 @@ func handleCopyError(err error) error {
return err return err
} }
// formatBytes 将字节数转换为可读的格式MB/KB/Bytes
func formatBytes(bytes int64) string {
const (
MB = 1024 * 1024
KB = 1024
)
switch {
case bytes >= MB:
return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
case bytes >= KB:
return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
default:
return fmt.Sprintf("%d Bytes", bytes)
}
}

46
internal/utils/utils.go Normal file
View File

@ -0,0 +1,46 @@
package utils
import (
"fmt"
"net"
"net/http"
"strings"
)
func GetClientIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return ip
}
return r.RemoteAddr
}
// 获取请求来源
func GetRequestSource(r *http.Request) string {
referer := r.Header.Get("Referer")
if referer != "" {
return fmt.Sprintf(" (from: %s)", referer)
}
return ""
}
func FormatBytes(bytes int64) string {
const (
MB = 1024 * 1024
KB = 1024
)
switch {
case bytes >= MB:
return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
case bytes >= KB:
return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
default:
return fmt.Sprintf("%d Bytes", bytes)
}
}