mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 08:31:55 +08:00
refactor(internal/handler, internal/middleware): Move getClientIP and formatBytes to utils package
This commit is contained in:
parent
c2c6e14736
commit
3c380ef5e9
@ -4,9 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"proxy-go/internal/utils"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -33,7 +33,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprint(w, "Welcome to CZL proxy.")
|
fmt.Fprint(w, "Welcome to CZL proxy.")
|
||||||
log.Printf("[%s] %s %s -> %d (root path) [%v]",
|
log.Printf("[%s] %s %s -> %d (root path) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, http.StatusOK, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if matchedPrefix == "" {
|
if matchedPrefix == "" {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
log.Printf("[%s] %s %s -> 404 (not found) [%v]",
|
log.Printf("[%s] %s %s -> 404 (not found) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Error decoding path", http.StatusInternalServerError)
|
http.Error(w, "Error decoding path", http.StatusInternalServerError)
|
||||||
log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]",
|
log.Printf("[%s] %s %s -> 500 (error decoding path: %v) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Error parsing target URL", http.StatusInternalServerError)
|
http.Error(w, "Error parsing target URL", http.StatusInternalServerError)
|
||||||
log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]",
|
log.Printf("[%s] %s %s -> 500 (error parsing URL: %v) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Error creating proxy request", http.StatusInternalServerError)
|
http.Error(w, "Error creating proxy request", http.StatusInternalServerError)
|
||||||
log.Printf("[%s] %s %s -> 500 (error: %v) [%v]",
|
log.Printf("[%s] %s %s -> 500 (error: %v) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,12 +99,12 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// 设置必要的头部,使用目标站点的 Host
|
// 设置必要的头部,使用目标站点的 Host
|
||||||
proxyReq.Host = parsedURL.Host
|
proxyReq.Host = parsedURL.Host
|
||||||
proxyReq.Header.Set("Host", parsedURL.Host)
|
proxyReq.Header.Set("Host", parsedURL.Host)
|
||||||
proxyReq.Header.Set("X-Real-IP", getClientIP(r))
|
proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r))
|
||||||
proxyReq.Header.Set("X-Forwarded-Host", r.Host)
|
proxyReq.Header.Set("X-Forwarded-Host", r.Host)
|
||||||
proxyReq.Header.Set("X-Forwarded-Proto", r.URL.Scheme)
|
proxyReq.Header.Set("X-Forwarded-Proto", r.URL.Scheme)
|
||||||
|
|
||||||
// 添加或更新 X-Forwarded-For
|
// 添加或更新 X-Forwarded-For
|
||||||
if clientIP := getClientIP(r); clientIP != "" {
|
if clientIP := utils.GetClientIP(r); clientIP != "" {
|
||||||
if prior := proxyReq.Header.Get("X-Forwarded-For"); prior != "" {
|
if prior := proxyReq.Header.Get("X-Forwarded-For"); prior != "" {
|
||||||
proxyReq.Header.Set("X-Forwarded-For", prior+", "+clientIP)
|
proxyReq.Header.Set("X-Forwarded-For", prior+", "+clientIP)
|
||||||
} else {
|
} else {
|
||||||
@ -123,7 +123,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Error forwarding request", http.StatusBadGateway)
|
http.Error(w, "Error forwarding request", http.StatusBadGateway)
|
||||||
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
|
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@ -165,9 +165,9 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 记录访问日志
|
// 记录访问日志
|
||||||
log.Printf("[%s] %s %s -> %s -> %d (%s) [%v]",
|
log.Printf("[%s] %s %s%s -> %s -> %d (%s) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, targetURL,
|
utils.GetClientIP(r), r.Method, r.URL.Path, utils.GetRequestSource(r), targetURL,
|
||||||
resp.StatusCode, formatBytes(bytesCopied), time.Since(startTime))
|
resp.StatusCode, utils.FormatBytes(bytesCopied), time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyHeader(dst, src http.Header) {
|
func copyHeader(dst, src http.Header) {
|
||||||
@ -177,32 +177,3 @@ func copyHeader(dst, src http.Header) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getClientIP(r *http.Request) string {
|
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
|
||||||
return strings.Split(ip, ",")[0]
|
|
||||||
}
|
|
||||||
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatBytes(bytes int64) string {
|
|
||||||
const (
|
|
||||||
MB = 1024 * 1024
|
|
||||||
KB = 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case bytes >= MB:
|
|
||||||
return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
|
|
||||||
case bytes >= KB:
|
|
||||||
return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%d Bytes", bytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -2,12 +2,11 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"proxy-go/internal/config"
|
"proxy-go/internal/config"
|
||||||
|
"proxy-go/internal/utils"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@ -34,7 +33,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Error creating proxy request", http.StatusInternalServerError)
|
http.Error(w, "Error creating proxy request", http.StatusInternalServerError)
|
||||||
log.Printf("[%s] %s %s -> 500 (error creating request: %v) [%v]",
|
log.Printf("[%s] %s %s -> 500 (error creating request: %v) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,7 +47,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
// 设置必要的头部
|
// 设置必要的头部
|
||||||
proxyReq.Host = cfg.TargetHost
|
proxyReq.Host = cfg.TargetHost
|
||||||
proxyReq.Header.Set("Host", cfg.TargetHost)
|
proxyReq.Header.Set("Host", cfg.TargetHost)
|
||||||
proxyReq.Header.Set("X-Real-IP", getClientIP(r))
|
proxyReq.Header.Set("X-Real-IP", utils.GetClientIP(r))
|
||||||
proxyReq.Header.Set("X-Scheme", r.URL.Scheme)
|
proxyReq.Header.Set("X-Scheme", r.URL.Scheme)
|
||||||
|
|
||||||
// 发送代理请求
|
// 发送代理请求
|
||||||
@ -57,7 +56,7 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Error forwarding request", http.StatusBadGateway)
|
http.Error(w, "Error forwarding request", http.StatusBadGateway)
|
||||||
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
|
log.Printf("[%s] %s %s -> 502 (proxy error: %v) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
utils.GetClientIP(r), r.Method, r.URL.Path, err, time.Since(startTime))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@ -75,13 +74,13 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
// 复制响应体
|
// 复制响应体
|
||||||
bytesCopied, err := io.Copy(w, resp.Body)
|
bytesCopied, err := io.Copy(w, resp.Body)
|
||||||
if err := handleCopyError(err); err != nil {
|
if err := handleCopyError(err); err != nil {
|
||||||
log.Printf("[%s] Error copying response: %v", getClientIP(r), err)
|
log.Printf("[%s] Error copying response: %v", utils.GetClientIP(r), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录成功的请求
|
// 记录成功的请求
|
||||||
log.Printf("[%s] %s %s -> %s -> %d (%s) [%v]",
|
log.Printf("[%s] %s %s%s -> %s -> %d (%s) [%v]",
|
||||||
getClientIP(r), r.Method, r.URL.Path, targetURL,
|
utils.GetClientIP(r), r.Method, r.URL.Path, utils.GetRequestSource(r), targetURL,
|
||||||
resp.StatusCode, formatBytes(bytesCopied), time.Since(startTime))
|
resp.StatusCode, utils.FormatBytes(bytesCopied), time.Since(startTime))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -93,20 +92,6 @@ func FixedPathProxyMiddleware(configs []config.FixedPathConfig) func(http.Handle
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientIP 获取客户端IP
|
|
||||||
func getClientIP(r *http.Request) string {
|
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
|
||||||
return strings.Split(ip, ",")[0]
|
|
||||||
}
|
|
||||||
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleCopyError(err error) error {
|
func handleCopyError(err error) error {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -122,20 +107,3 @@ func handleCopyError(err error) error {
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatBytes 将字节数转换为可读的格式(MB/KB/Bytes)
|
|
||||||
func formatBytes(bytes int64) string {
|
|
||||||
const (
|
|
||||||
MB = 1024 * 1024
|
|
||||||
KB = 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case bytes >= MB:
|
|
||||||
return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
|
|
||||||
case bytes >= KB:
|
|
||||||
return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%d Bytes", bytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
46
internal/utils/utils.go
Normal file
46
internal/utils/utils.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetClientIP(r *http.Request) string {
|
||||||
|
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||||
|
return strings.Split(ip, ",")[0]
|
||||||
|
}
|
||||||
|
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求来源
|
||||||
|
func GetRequestSource(r *http.Request) string {
|
||||||
|
referer := r.Header.Get("Referer")
|
||||||
|
if referer != "" {
|
||||||
|
return fmt.Sprintf(" (from: %s)", referer)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func FormatBytes(bytes int64) string {
|
||||||
|
const (
|
||||||
|
MB = 1024 * 1024
|
||||||
|
KB = 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case bytes >= MB:
|
||||||
|
return fmt.Sprintf("%.2f MB", float64(bytes)/MB)
|
||||||
|
case bytes >= KB:
|
||||||
|
return fmt.Sprintf("%.2f KB", float64(bytes)/KB)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d Bytes", bytes)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user