From 2f0be5f38d7bcb9cb918c11f8a22e54440f49990 Mon Sep 17 00:00:00 2001 From: wood chen Date: Wed, 13 Nov 2024 17:50:06 +0800 Subject: [PATCH] feat(handler): add mirror proxy handler for URL prefix "/mirror/" --- internal/handler/mirror_proxy.go | 85 ++++++++++++++++++++++++++++++++ main.go | 8 +++ 2 files changed, 93 insertions(+) create mode 100644 internal/handler/mirror_proxy.go diff --git a/internal/handler/mirror_proxy.go b/internal/handler/mirror_proxy.go new file mode 100644 index 0000000..19779da --- /dev/null +++ b/internal/handler/mirror_proxy.go @@ -0,0 +1,85 @@ +// handler/mirror_proxy.go +package handler + +import ( + "io" + "log" + "net/http" + "proxy-go/internal/utils" + "strings" + "time" +) + +type MirrorProxyHandler struct{} + +func NewMirrorProxyHandler() *MirrorProxyHandler { + return &MirrorProxyHandler{} +} + +func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + // 从路径中提取实际URL + // 例如:/mirror/https://example.com/path 变成 https://example.com/path + actualURL := strings.TrimPrefix(r.URL.Path, "/mirror/") + if actualURL == "" || actualURL == r.URL.Path { + http.Error(w, "Invalid URL", http.StatusBadRequest) + return + } + + // 添加原始请求中的查询参数和片段 + if r.URL.RawQuery != "" { + actualURL += "?" + r.URL.RawQuery + } + + // 创建新的请求 + proxyReq, err := http.NewRequest(r.Method, actualURL, r.Body) + if err != nil { + http.Error(w, "Error creating request", http.StatusInternalServerError) + log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error creating request: %v", + r.Method, http.StatusInternalServerError, time.Since(startTime), + utils.GetClientIP(r), "-", actualURL, err) + return + } + + // 复制原始请求的header + copyHeader(proxyReq.Header, r.Header) + + // 发送请求 + client := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + }, + } + resp, err := client.Do(proxyReq) + if err != nil { + http.Error(w, "Error forwarding request", http.StatusBadGateway) + log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error forwarding request: %v", + r.Method, http.StatusBadGateway, time.Since(startTime), + utils.GetClientIP(r), "-", actualURL, err) + return + } + defer resp.Body.Close() + + // 设置CORS头 + w.Header().Set("Access-Control-Allow-Origin", "*") + + // 复制响应头 + copyHeader(w.Header(), resp.Header) + + // 设置状态码 + w.WriteHeader(resp.StatusCode) + + // 复制响应体 + bytesCopied, err := io.Copy(w, resp.Body) + if err != nil { + log.Printf("Error copying response: %v", err) + return + } + + // 记录访问日志 + log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s", + r.Method, resp.StatusCode, time.Since(startTime), + utils.GetClientIP(r), utils.FormatBytes(bytesCopied), + utils.GetRequestSource(r), actualURL) +} diff --git a/main.go b/main.go index f279ba6..03f6307 100644 --- a/main.go +++ b/main.go @@ -27,6 +27,7 @@ func main() { }) // 创建代理处理器 + mirrorHandler := handler.NewMirrorProxyHandler() proxyHandler := handler.NewProxyHandler(cfg.MAP) // 创建处理器链 @@ -34,6 +35,13 @@ func main() { matcher func(*http.Request) bool handler http.Handler }{ + // Mirror代理处理器 + { + matcher: func(r *http.Request) bool { + return strings.HasPrefix(r.URL.Path, "/mirror/") + }, + handler: mirrorHandler, + }, // 固定路径处理器 { matcher: func(r *http.Request) bool {