refactor(proxy): Improve error handling and response writing in proxy handler

- Modify error response writing to use explicit WriteHeader and Write methods
- Remove redundant WriteHeader call in ServeHTTP method
- Add WriteHeader before writing response for both small and large responses
- Enhance code readability and error handling consistency
This commit is contained in:
wood chen 2025-02-15 08:34:14 +08:00
parent 53b16f44fb
commit 6c7bc2bfb8

View File

@ -228,19 +228,19 @@ type ProxyHandler struct {
rateLimiter *rateLimitManager rateLimiter *rateLimitManager
} }
// 修改参数类型 // NewProxyHandler 创建新的代理处理器
func NewProxyHandler(cfg *config.Config) *ProxyHandler { func NewProxyHandler(cfg *config.Config) *ProxyHandler {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: clientConnTimeout, // 客户端连接超时 Timeout: clientConnTimeout,
KeepAlive: 30 * time.Second, // TCP keepalive 间隔 KeepAlive: 30 * time.Second,
} }
transport := &http.Transport{ transport := &http.Transport{
DialContext: dialer.DialContext, DialContext: dialer.DialContext,
MaxIdleConns: 200, MaxIdleConns: 200,
MaxIdleConnsPerHost: 20, MaxIdleConnsPerHost: 20,
IdleConnTimeout: idleConnTimeout, // 空闲连接超时 IdleConnTimeout: idleConnTimeout,
TLSHandshakeTimeout: tlsHandshakeTimeout, // TLS握手超时 TLSHandshakeTimeout: tlsHandshakeTimeout,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
MaxConnsPerHost: 50, MaxConnsPerHost: 50,
DisableKeepAlives: false, DisableKeepAlives: false,
@ -248,14 +248,14 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler {
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
WriteBufferSize: 64 * 1024, WriteBufferSize: 64 * 1024,
ReadBufferSize: 64 * 1024, ReadBufferSize: 64 * 1024,
ResponseHeaderTimeout: backendServTimeout, // 后端服务响应头超时 ResponseHeaderTimeout: backendServTimeout,
} }
return &ProxyHandler{ return &ProxyHandler{
pathMap: cfg.MAP, pathMap: cfg.MAP,
client: &http.Client{ client: &http.Client{
Transport: transport, Transport: transport,
Timeout: proxyRespTimeout, // 整体代理响应超时 Timeout: proxyRespTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 { if len(via) >= 10 {
return fmt.Errorf("stopped after 10 redirects") return fmt.Errorf("stopped after 10 redirects")
@ -271,9 +271,11 @@ func NewProxyHandler(cfg *config.Config) *ProxyHandler {
errorHandler: func(w http.ResponseWriter, r *http.Request, err error) { errorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
log.Printf("[Error] %s %s -> %v", r.Method, r.URL.Path, err) log.Printf("[Error] %s %s -> %v", r.Method, r.URL.Path, err)
if strings.Contains(err.Error(), "rate limit exceeded") { if strings.Contains(err.Error(), "rate limit exceeded") {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("Too Many Requests"))
} else { } else {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Internal Server Error"))
} }
}, },
} }
@ -465,9 +467,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 删除严格的 CSP // 删除严格的 CSP
w.Header().Del("Content-Security-Policy") w.Header().Del("Content-Security-Policy")
// 设置响应状态码
w.WriteHeader(resp.StatusCode)
// 根据响应大小选择不同的处理策略 // 根据响应大小选择不同的处理策略
contentLength := resp.ContentLength contentLength := resp.ContentLength
if contentLength > 0 && contentLength < 1<<20 { // 1MB 以下的小响应 if contentLength > 0 && contentLength < 1<<20 { // 1MB 以下的小响应
@ -482,7 +481,8 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// 一次性写入响应 // 设置响应状态码并一次性写入响应
w.WriteHeader(resp.StatusCode)
written, err := w.Write(buf.Bytes()) written, err := w.Write(buf.Bytes())
if err != nil { if err != nil {
if !isConnectionClosed(err) { if !isConnectionClosed(err) {
@ -492,6 +492,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(written), utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), int64(written), utils.GetClientIP(r), r)
} else { } else {
// 大响应使用零拷贝传输 // 大响应使用零拷贝传输
w.WriteHeader(resp.StatusCode)
var bytesCopied int64 var bytesCopied int64
var err error var err error