Compare commits

...

2 Commits

16 changed files with 1161 additions and 54 deletions

View File

@ -57,5 +57,14 @@
"Enabled": false, "Enabled": false,
"Level": 4 "Level": 4
} }
},
"Security": {
"IPBan": {
"Enabled": true,
"ErrorThreshold": 10,
"WindowMinutes": 5,
"BanDurationMinutes": 5,
"CleanupIntervalMinutes": 1
}
} }
} }

9
go.mod
View File

@ -1,10 +1,15 @@
module proxy-go module proxy-go
go 1.23.1 go 1.24
toolchain go1.24.4
require ( require (
github.com/andybalholm/brotli v1.1.1 github.com/andybalholm/brotli v1.1.1
golang.org/x/net v0.40.0 golang.org/x/net v0.40.0
) )
require golang.org/x/text v0.25.0 // indirect require (
github.com/woodchen-ink/go-web-utils v1.0.0 // indirect
golang.org/x/text v0.25.0 // indirect
)

4
go.sum
View File

@ -1,5 +1,9 @@
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/woodchen-ink/go-web-utils v0.0.0-20250621140947-08c57486fe2e h1:k/D90giyDyL5hDPJGGQexqZ423WmZqRUUxc/yQ6E8ws=
github.com/woodchen-ink/go-web-utils v0.0.0-20250621140947-08c57486fe2e/go.mod h1:d+L8rZ7xekLnf679XRvfwqpl4M8RCNdWSViaB3GmpnI=
github.com/woodchen-ink/go-web-utils v1.0.0 h1:Kybe0ZPhRI4w5FJ4bZdPcepNEKTmbw3to3xLR31e+ws=
github.com/woodchen-ink/go-web-utils v1.0.0/go.mod h1:hpiT30rd5Egj2LqRwYBqbEtUXjhjh/Qary0S14KCZgw=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=

View File

@ -7,6 +7,7 @@ import (
type Config struct { type Config struct {
MAP map[string]PathConfig `json:"MAP"` // 路径映射配置 MAP map[string]PathConfig `json:"MAP"` // 路径映射配置
Compression CompressionConfig `json:"Compression"` Compression CompressionConfig `json:"Compression"`
Security SecurityConfig `json:"Security"` // 安全配置
} }
type PathConfig struct { type PathConfig struct {
@ -36,6 +37,18 @@ type CompressorConfig struct {
Level int `json:"Level"` Level int `json:"Level"`
} }
type SecurityConfig struct {
IPBan IPBanConfig `json:"IPBan"` // IP封禁配置
}
type IPBanConfig struct {
Enabled bool `json:"Enabled"` // 是否启用IP封禁
ErrorThreshold int `json:"ErrorThreshold"` // 404错误阈值
WindowMinutes int `json:"WindowMinutes"` // 统计窗口时间(分钟)
BanDurationMinutes int `json:"BanDurationMinutes"` // 封禁时长(分钟)
CleanupIntervalMinutes int `json:"CleanupIntervalMinutes"` // 清理间隔(分钟)
}
// 扩展名映射配置结构 // 扩展名映射配置结构
type ExtRuleConfig struct { type ExtRuleConfig struct {
Extensions string `json:"Extensions"` // 逗号分隔的扩展名 Extensions string `json:"Extensions"` // 逗号分隔的扩展名

View File

@ -14,6 +14,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/woodchen-ink/go-web-utils/iputil"
) )
const ( const (
@ -154,7 +156,7 @@ func (h *ProxyHandler) CheckAuth(token string) bool {
func (h *ProxyHandler) LogoutHandler(w http.ResponseWriter, r *http.Request) { func (h *ProxyHandler) LogoutHandler(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
if auth == "" || !strings.HasPrefix(auth, "Bearer ") { if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
log.Printf("[Auth] ERR %s %s -> 401 (%s) no token from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 401 (%s) no token from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
@ -162,7 +164,7 @@ func (h *ProxyHandler) LogoutHandler(w http.ResponseWriter, r *http.Request) {
token := strings.TrimPrefix(auth, "Bearer ") token := strings.TrimPrefix(auth, "Bearer ")
h.auth.tokens.Delete(token) h.auth.tokens.Delete(token)
log.Printf("[Auth] %s %s -> 200 (%s) logout success from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] %s %s -> 200 (%s) logout success from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{ json.NewEncoder(w).Encode(map[string]string{
@ -175,14 +177,14 @@ func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
if auth == "" || !strings.HasPrefix(auth, "Bearer ") { if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
log.Printf("[Auth] ERR %s %s -> 401 (%s) no token from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 401 (%s) no token from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
token := strings.TrimPrefix(auth, "Bearer ") token := strings.TrimPrefix(auth, "Bearer ")
if !h.auth.validateToken(token) { if !h.auth.validateToken(token) {
log.Printf("[Auth] ERR %s %s -> 401 (%s) invalid token from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 401 (%s) invalid token from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Invalid token", http.StatusUnauthorized) http.Error(w, "Invalid token", http.StatusUnauthorized)
return return
} }
@ -253,14 +255,14 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 验证 state // 验证 state
if !h.auth.validateState(state) { if !h.auth.validateState(state) {
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s", log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s",
r.Method, r.URL.Path, utils.GetClientIP(r), state, utils.GetRequestSource(r)) r.Method, r.URL.Path, iputil.GetClientIP(r), state, utils.GetRequestSource(r))
http.Error(w, "Invalid state", http.StatusBadRequest) http.Error(w, "Invalid state", http.StatusBadRequest)
return return
} }
// 验证code参数 // 验证code参数
if code == "" { if code == "" {
log.Printf("[Auth] ERR %s %s -> 400 (%s) missing code parameter from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 400 (%s) missing code parameter from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Missing code parameter", http.StatusBadRequest) http.Error(w, "Missing code parameter", http.StatusBadRequest)
return return
} }
@ -272,7 +274,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 验证OAuth配置 // 验证OAuth配置
if clientID == "" || clientSecret == "" { if clientID == "" || clientSecret == "" {
log.Printf("[Auth] ERR %s %s -> 500 (%s) missing OAuth credentials from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 500 (%s) missing OAuth credentials from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Server configuration error", http.StatusInternalServerError) http.Error(w, "Server configuration error", http.StatusInternalServerError)
return return
} }
@ -290,7 +292,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
"client_secret": {clientSecret}, "client_secret": {clientSecret},
}) })
if err != nil { if err != nil {
log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to get access token: %v from %s", r.Method, r.URL.Path, utils.GetClientIP(r), err, utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to get access token: %v from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), err, utils.GetRequestSource(r))
http.Error(w, "Failed to get access token", http.StatusInternalServerError) http.Error(w, "Failed to get access token", http.StatusInternalServerError)
return return
} }
@ -301,21 +303,21 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 读取错误响应内容 // 读取错误响应内容
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s", log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s",
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, string(bodyBytes)) r.Method, r.URL.Path, resp.StatusCode, iputil.GetClientIP(r), resp.Status, string(bodyBytes))
http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError) http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError)
return return
} }
var token OAuthToken var token OAuthToken
if err := json.NewDecoder(resp.Body).Decode(&token); err != nil { if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to parse token response: %v from %s", r.Method, r.URL.Path, utils.GetClientIP(r), err, utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to parse token response: %v from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), err, utils.GetRequestSource(r))
http.Error(w, "Failed to parse token response", http.StatusInternalServerError) http.Error(w, "Failed to parse token response", http.StatusInternalServerError)
return return
} }
// 验证访问令牌 // 验证访问令牌
if token.AccessToken == "" { if token.AccessToken == "" {
log.Printf("[Auth] ERR %s %s -> 500 (%s) received empty access token from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 500 (%s) received empty access token from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Received invalid token", http.StatusInternalServerError) http.Error(w, "Received invalid token", http.StatusInternalServerError)
return return
} }
@ -326,7 +328,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
client := &http.Client{Timeout: 10 * time.Second} client := &http.Client{Timeout: 10 * time.Second}
userResp, err := client.Do(req) userResp, err := client.Do(req)
if err != nil { if err != nil {
log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to get user info: %v from %s", r.Method, r.URL.Path, utils.GetClientIP(r), err, utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to get user info: %v from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), err, utils.GetRequestSource(r))
http.Error(w, "Failed to get user info", http.StatusInternalServerError) http.Error(w, "Failed to get user info", http.StatusInternalServerError)
return return
} }
@ -335,7 +337,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 检查用户信息响应状态码 // 检查用户信息响应状态码
if userResp.StatusCode != http.StatusOK { if userResp.StatusCode != http.StatusOK {
log.Printf("[Auth] ERR %s %s -> %d (%s) userinfo endpoint returned error status: %s from %s", log.Printf("[Auth] ERR %s %s -> %d (%s) userinfo endpoint returned error status: %s from %s",
r.Method, r.URL.Path, userResp.StatusCode, utils.GetClientIP(r), userResp.Status, utils.GetRequestSource(r)) r.Method, r.URL.Path, userResp.StatusCode, iputil.GetClientIP(r), userResp.Status, utils.GetRequestSource(r))
http.Error(w, "Failed to get user info: "+userResp.Status, http.StatusInternalServerError) http.Error(w, "Failed to get user info: "+userResp.Status, http.StatusInternalServerError)
return return
} }
@ -344,7 +346,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
bodyBytes, err := io.ReadAll(userResp.Body) bodyBytes, err := io.ReadAll(userResp.Body)
if err != nil { if err != nil {
log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to read user info response body: %v from %s", log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to read user info response body: %v from %s",
r.Method, r.URL.Path, utils.GetClientIP(r), err, utils.GetRequestSource(r)) r.Method, r.URL.Path, iputil.GetClientIP(r), err, utils.GetRequestSource(r))
http.Error(w, "Failed to read user info response", http.StatusInternalServerError) http.Error(w, "Failed to read user info response", http.StatusInternalServerError)
return return
} }
@ -356,7 +358,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
var rawUserInfo map[string]interface{} var rawUserInfo map[string]interface{}
if err := json.Unmarshal(bodyBytes, &rawUserInfo); err != nil { if err := json.Unmarshal(bodyBytes, &rawUserInfo); err != nil {
log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to parse raw user info: %v from %s", log.Printf("[Auth] ERR %s %s -> 500 (%s) failed to parse raw user info: %v from %s",
r.Method, r.URL.Path, utils.GetClientIP(r), err, utils.GetRequestSource(r)) r.Method, r.URL.Path, iputil.GetClientIP(r), err, utils.GetRequestSource(r))
http.Error(w, "Failed to parse user info", http.StatusInternalServerError) http.Error(w, "Failed to parse user info", http.StatusInternalServerError)
return return
} }
@ -391,7 +393,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 验证用户信息 // 验证用户信息
if userInfo.Username == "" { if userInfo.Username == "" {
log.Printf("[Auth] ERR %s %s -> 500 (%s) could not extract username from user info from %s", log.Printf("[Auth] ERR %s %s -> 500 (%s) could not extract username from user info from %s",
r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
http.Error(w, "Invalid user information: missing username", http.StatusInternalServerError) http.Error(w, "Invalid user information: missing username", http.StatusInternalServerError)
return return
} }
@ -400,7 +402,7 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
internalToken := h.auth.generateToken() internalToken := h.auth.generateToken()
h.auth.addToken(internalToken, userInfo.Username, tokenExpiry) h.auth.addToken(internalToken, userInfo.Username, tokenExpiry)
log.Printf("[Auth] %s %s -> 200 (%s) login success for user %s from %s", r.Method, r.URL.Path, utils.GetClientIP(r), userInfo.Username, utils.GetRequestSource(r)) log.Printf("[Auth] %s %s -> 200 (%s) login success for user %s from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), userInfo.Username, utils.GetRequestSource(r))
// 返回登录成功页面 // 返回登录成功页面
w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Type", "text/html")

View File

@ -11,6 +11,8 @@ import (
"proxy-go/internal/utils" "proxy-go/internal/utils"
"strings" "strings"
"time" "time"
"github.com/woodchen-ink/go-web-utils/iputil"
) )
type MirrorProxyHandler struct { type MirrorProxyHandler struct {
@ -56,7 +58,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | CORS Preflight", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | CORS Preflight",
r.Method, http.StatusOK, time.Since(startTime), r.Method, http.StatusOK, time.Since(startTime),
utils.GetClientIP(r), "-", r.URL.Path) iputil.GetClientIP(r), "-", r.URL.Path)
return return
} }
@ -66,7 +68,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid URL", http.StatusBadRequest) http.Error(w, "Invalid URL", http.StatusBadRequest)
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Invalid URL", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Invalid URL",
r.Method, http.StatusBadRequest, time.Since(startTime), r.Method, http.StatusBadRequest, time.Since(startTime),
utils.GetClientIP(r), "-", r.URL.Path) iputil.GetClientIP(r), "-", r.URL.Path)
return return
} }
@ -80,7 +82,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid URL", http.StatusBadRequest) http.Error(w, "Invalid URL", http.StatusBadRequest)
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Parse URL error: %v", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Parse URL error: %v",
r.Method, http.StatusBadRequest, time.Since(startTime), r.Method, http.StatusBadRequest, time.Since(startTime),
utils.GetClientIP(r), "-", actualURL, err) iputil.GetClientIP(r), "-", actualURL, err)
return return
} }
@ -98,7 +100,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Error creating request", http.StatusInternalServerError) http.Error(w, "Error creating request", http.StatusInternalServerError)
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error creating request: %v", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error creating request: %v",
r.Method, http.StatusInternalServerError, time.Since(startTime), r.Method, http.StatusInternalServerError, time.Since(startTime),
utils.GetClientIP(r), "-", actualURL, err) iputil.GetClientIP(r), "-", actualURL, err)
return return
} }
@ -131,7 +133,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
http.ServeFile(w, r, item.FilePath) http.ServeFile(w, r, item.FilePath)
collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(startTime), item.Size, utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(startTime), item.Size, iputil.GetClientIP(r), r)
return return
} }
} }
@ -142,7 +144,7 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Error forwarding request", http.StatusBadGateway) http.Error(w, "Error forwarding request", http.StatusBadGateway)
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error forwarding request: %v", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | Error forwarding request: %v",
r.Method, http.StatusBadGateway, time.Since(startTime), r.Method, http.StatusBadGateway, time.Since(startTime),
utils.GetClientIP(r), "-", actualURL, err) iputil.GetClientIP(r), "-", actualURL, err)
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -183,9 +185,9 @@ func (h *MirrorProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 记录访问日志 // 记录访问日志
log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s", log.Printf("| %-6s | %3d | %12s | %15s | %10s | %-30s | %s",
r.Method, resp.StatusCode, time.Since(startTime), r.Method, resp.StatusCode, time.Since(startTime),
utils.GetClientIP(r), utils.FormatBytes(written), iputil.GetClientIP(r), utils.FormatBytes(written),
utils.GetRequestSource(r), actualURL) utils.GetRequestSource(r), actualURL)
// 记录统计信息 // 记录统计信息
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(startTime), written, iputil.GetClientIP(r), r)
} }

View File

@ -17,6 +17,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/woodchen-ink/go-web-utils/iputil"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
@ -232,7 +233,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "Welcome to CZL proxy.") fmt.Fprint(w, "Welcome to CZL proxy.")
log.Printf("[Proxy] %s %s -> %d (%s) from %s", r.Method, r.URL.Path, http.StatusOK, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Proxy] %s %s -> %d (%s) from %s", r.Method, r.URL.Path, http.StatusOK, iputil.GetClientIP(r), utils.GetRequestSource(r))
return return
} }
@ -258,7 +259,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 检查是否需要进行302跳转 // 检查是否需要进行302跳转
if h.redirectHandler != nil && h.redirectHandler.HandleRedirect(w, r, pathConfig, decodedPath, h.client) { if h.redirectHandler != nil && h.redirectHandler.HandleRedirect(w, r, pathConfig, decodedPath, h.client) {
// 如果进行了302跳转直接返回不继续处理 // 如果进行了302跳转直接返回不继续处理
collector.RecordRequest(r.URL.Path, http.StatusFound, time.Since(start), 0, utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, http.StatusFound, time.Since(start), 0, iputil.GetClientIP(r), r)
return return
} }
@ -342,7 +343,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// 设置最小必要的代理头部 // 设置最小必要的代理头部
clientIP := utils.GetClientIP(r) clientIP := iputil.GetClientIP(r)
proxyReq.Header.Set("X-Real-IP", clientIP) proxyReq.Header.Set("X-Real-IP", clientIP)
// 添加或更新 X-Forwarded-For - 减少重复获取客户端IP // 添加或更新 X-Forwarded-For - 减少重复获取客户端IP
@ -389,7 +390,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
http.ServeFile(w, r, item.FilePath) http.ServeFile(w, r, item.FilePath)
collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(start), item.Size, utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, http.StatusOK, time.Since(start), item.Size, iputil.GetClientIP(r), r)
return return
} }
} }
@ -399,10 +400,10 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
if ctx.Err() == context.DeadlineExceeded { if ctx.Err() == context.DeadlineExceeded {
h.errorHandler(w, r, fmt.Errorf("request timeout after %v", proxyRespTimeout)) h.errorHandler(w, r, fmt.Errorf("request timeout after %v", proxyRespTimeout))
log.Printf("[Proxy] ERR %s %s -> 408 (%s) timeout from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Proxy] ERR %s %s -> 408 (%s) timeout from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
} else { } else {
h.errorHandler(w, r, fmt.Errorf("proxy error: %v", err)) h.errorHandler(w, r, fmt.Errorf("proxy error: %v", err))
log.Printf("[Proxy] ERR %s %s -> 502 (%s) proxy error from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Proxy] ERR %s %s -> 502 (%s) proxy error from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
} }
return return
} }
@ -450,7 +451,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
written, err = io.CopyBuffer(w, resp.Body, buf) written, err = io.CopyBuffer(w, resp.Body, buf)
if err != nil && !isConnectionClosed(err) { if err != nil && !isConnectionClosed(err) {
log.Printf("[Proxy] ERR %s %s -> write error (%s) from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Proxy] ERR %s %s -> write error (%s) from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
return return
} }
} }
@ -461,13 +462,13 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
written, err = io.CopyBuffer(w, resp.Body, buf) written, err = io.CopyBuffer(w, resp.Body, buf)
if err != nil && !isConnectionClosed(err) { if err != nil && !isConnectionClosed(err) {
log.Printf("[Proxy] ERR %s %s -> write error (%s) from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Proxy] ERR %s %s -> write error (%s) from %s", r.Method, r.URL.Path, iputil.GetClientIP(r), utils.GetRequestSource(r))
return return
} }
} }
// 记录统计信息 // 记录统计信息
collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), written, utils.GetClientIP(r), r) collector.RecordRequest(r.URL.Path, resp.StatusCode, time.Since(start), written, iputil.GetClientIP(r), r)
} }
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {

View File

@ -8,6 +8,8 @@ import (
"proxy-go/internal/service" "proxy-go/internal/service"
"proxy-go/internal/utils" "proxy-go/internal/utils"
"strings" "strings"
"github.com/woodchen-ink/go-web-utils/iputil"
) )
// RedirectHandler 处理302跳转逻辑 // RedirectHandler 处理302跳转逻辑
@ -99,7 +101,7 @@ func (rh *RedirectHandler) performRedirect(w http.ResponseWriter, r *http.Reques
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
// 记录跳转日志 // 记录跳转日志
clientIP := utils.GetClientIP(r) clientIP := iputil.GetClientIP(r)
log.Printf("[Redirect] %s %s -> 302 %s (%s) from %s", log.Printf("[Redirect] %s %s -> 302 %s (%s) from %s",
r.Method, r.URL.Path, targetURL, clientIP, utils.GetRequestSource(r)) r.Method, r.URL.Path, targetURL, clientIP, utils.GetRequestSource(r))
} }

View File

@ -0,0 +1,130 @@
package handler
import (
"encoding/json"
"net/http"
"proxy-go/internal/security"
"time"
"github.com/woodchen-ink/go-web-utils/iputil"
)
// SecurityHandler 安全管理处理器
type SecurityHandler struct {
banManager *security.IPBanManager
}
// NewSecurityHandler 创建安全管理处理器
func NewSecurityHandler(banManager *security.IPBanManager) *SecurityHandler {
return &SecurityHandler{
banManager: banManager,
}
}
// GetBannedIPs 获取被封禁的IP列表
func (sh *SecurityHandler) GetBannedIPs(w http.ResponseWriter, r *http.Request) {
if sh.banManager == nil {
http.Error(w, "Security manager not enabled", http.StatusServiceUnavailable)
return
}
bannedIPs := sh.banManager.GetBannedIPs()
// 转换为前端友好的格式
result := make([]map[string]interface{}, 0, len(bannedIPs))
for ip, banEndTime := range bannedIPs {
result = append(result, map[string]interface{}{
"ip": ip,
"ban_end_time": banEndTime.Format("2006-01-02 15:04:05"),
"remaining_seconds": int64(time.Until(banEndTime).Seconds()),
})
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"banned_ips": result,
"count": len(result),
})
}
// UnbanIP 手动解封IP
func (sh *SecurityHandler) UnbanIP(w http.ResponseWriter, r *http.Request) {
if sh.banManager == nil {
http.Error(w, "Security manager not enabled", http.StatusServiceUnavailable)
return
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
IP string `json:"ip"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.IP == "" {
http.Error(w, "IP address is required", http.StatusBadRequest)
return
}
success := sh.banManager.UnbanIP(req.IP)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"success": success,
"message": func() string {
if success {
return "IP解封成功"
}
return "IP未在封禁列表中"
}(),
})
}
// GetSecurityStats 获取安全统计信息
func (sh *SecurityHandler) GetSecurityStats(w http.ResponseWriter, r *http.Request) {
if sh.banManager == nil {
http.Error(w, "Security manager not enabled", http.StatusServiceUnavailable)
return
}
stats := sh.banManager.GetStats()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(stats)
}
// CheckIPStatus 检查IP状态
func (sh *SecurityHandler) CheckIPStatus(w http.ResponseWriter, r *http.Request) {
if sh.banManager == nil {
http.Error(w, "Security manager not enabled", http.StatusServiceUnavailable)
return
}
ip := r.URL.Query().Get("ip")
if ip == "" {
// 如果没有指定IP使用请求的IP
ip = iputil.GetClientIP(r)
}
banned, banEndTime := sh.banManager.GetBanInfo(ip)
result := map[string]interface{}{
"ip": ip,
"banned": banned,
}
if banned {
result["ban_end_time"] = banEndTime.Format("2006-01-02 15:04:05")
result["remaining_seconds"] = int64(time.Until(banEndTime).Seconds())
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}

View File

@ -0,0 +1,86 @@
package middleware
import (
"fmt"
"net/http"
"proxy-go/internal/security"
"time"
"github.com/woodchen-ink/go-web-utils/iputil"
)
// SecurityMiddleware 安全中间件
type SecurityMiddleware struct {
banManager *security.IPBanManager
}
// NewSecurityMiddleware 创建安全中间件
func NewSecurityMiddleware(banManager *security.IPBanManager) *SecurityMiddleware {
return &SecurityMiddleware{
banManager: banManager,
}
}
// IPBanMiddleware IP封禁中间件
func (sm *SecurityMiddleware) IPBanMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := iputil.GetClientIP(r)
// 检查IP是否被封禁
if sm.banManager.IsIPBanned(clientIP) {
banned, banEndTime := sm.banManager.GetBanInfo(clientIP)
if banned {
// 返回429状态码和封禁信息
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", time.Until(banEndTime).Seconds()))
w.WriteHeader(http.StatusTooManyRequests)
remainingTime := time.Until(banEndTime)
response := fmt.Sprintf(`{
"error": "IP temporarily banned due to excessive 404 errors",
"message": "您的IP因频繁访问不存在的资源而被暂时封禁",
"ban_end_time": "%s",
"remaining_seconds": %.0f
}`, banEndTime.Format("2006-01-02 15:04:05"), remainingTime.Seconds())
w.Write([]byte(response))
return
}
}
// 创建响应写入器包装器来捕获状态码
wrapper := &responseWrapper{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// 继续处理请求
next.ServeHTTP(wrapper, r)
// 如果响应是404记录错误
if wrapper.statusCode == http.StatusNotFound {
sm.banManager.RecordError(clientIP)
}
})
}
// responseWrapper 响应包装器,用于捕获状态码
type responseWrapper struct {
http.ResponseWriter
statusCode int
}
// WriteHeader 重写WriteHeader方法来捕获状态码
func (rw *responseWrapper) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// Write 重写Write方法确保状态码被正确设置
func (rw *responseWrapper) Write(b []byte) (int, error) {
// 如果还没有设置状态码默认为200
if rw.statusCode == 0 {
rw.statusCode = http.StatusOK
}
return rw.ResponseWriter.Write(b)
}

View File

@ -0,0 +1,278 @@
package security
import (
"log"
"sync"
"time"
)
// IPBanManager IP封禁管理器
type IPBanManager struct {
// 404错误计数器 map[ip]count
errorCounts sync.Map
// IP封禁列表 map[ip]banEndTime
bannedIPs sync.Map
// 配置参数
config *IPBanConfig
// 清理任务停止信号
stopCleanup chan struct{}
// 清理任务等待组
cleanupWG sync.WaitGroup
}
// IPBanConfig IP封禁配置
type IPBanConfig struct {
// 404错误阈值超过此数量将被封禁
ErrorThreshold int `json:"error_threshold"`
// 统计窗口时间(分钟)
WindowMinutes int `json:"window_minutes"`
// 封禁时长(分钟)
BanDurationMinutes int `json:"ban_duration_minutes"`
// 清理间隔(分钟)
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
}
// errorRecord 错误记录
type errorRecord struct {
count int
firstTime time.Time
lastTime time.Time
}
// DefaultIPBanConfig 默认配置
func DefaultIPBanConfig() *IPBanConfig {
return &IPBanConfig{
ErrorThreshold: 10, // 10次404错误
WindowMinutes: 5, // 5分钟内
BanDurationMinutes: 5, // 封禁5分钟
CleanupIntervalMinutes: 1, // 每分钟清理一次
}
}
// NewIPBanManager 创建IP封禁管理器
func NewIPBanManager(config *IPBanConfig) *IPBanManager {
if config == nil {
config = DefaultIPBanConfig()
}
manager := &IPBanManager{
config: config,
stopCleanup: make(chan struct{}),
}
// 启动清理任务
manager.startCleanupTask()
log.Printf("[Security] IP封禁管理器已启动 - 阈值: %d次/%.0f分钟, 封禁时长: %.0f分钟",
config.ErrorThreshold,
float64(config.WindowMinutes),
float64(config.BanDurationMinutes))
return manager
}
// RecordError 记录404错误
func (m *IPBanManager) RecordError(ip string) {
now := time.Now()
windowStart := now.Add(-time.Duration(m.config.WindowMinutes) * time.Minute)
// 加载或创建错误记录
value, _ := m.errorCounts.LoadOrStore(ip, &errorRecord{
count: 0,
firstTime: now,
lastTime: now,
})
record := value.(*errorRecord)
// 如果第一次记录时间超出窗口,重置计数
if record.firstTime.Before(windowStart) {
record.count = 1
record.firstTime = now
record.lastTime = now
} else {
record.count++
record.lastTime = now
}
// 检查是否需要封禁
if record.count >= m.config.ErrorThreshold {
m.banIP(ip, now)
// 重置计数器,避免重复封禁
record.count = 0
record.firstTime = now
}
log.Printf("[Security] 记录404错误 IP: %s, 当前计数: %d/%d (窗口: %.0f分钟)",
ip, record.count, m.config.ErrorThreshold, float64(m.config.WindowMinutes))
}
// banIP 封禁IP
func (m *IPBanManager) banIP(ip string, banTime time.Time) {
banEndTime := banTime.Add(time.Duration(m.config.BanDurationMinutes) * time.Minute)
m.bannedIPs.Store(ip, banEndTime)
log.Printf("[Security] IP已被封禁: %s, 封禁至: %s (%.0f分钟)",
ip, banEndTime.Format("15:04:05"), float64(m.config.BanDurationMinutes))
}
// IsIPBanned 检查IP是否被封禁
func (m *IPBanManager) IsIPBanned(ip string) bool {
value, exists := m.bannedIPs.Load(ip)
if !exists {
return false
}
banEndTime := value.(time.Time)
now := time.Now()
// 检查封禁是否已过期
if now.After(banEndTime) {
m.bannedIPs.Delete(ip)
log.Printf("[Security] IP封禁已过期自动解封: %s", ip)
return false
}
return true
}
// GetBanInfo 获取IP封禁信息
func (m *IPBanManager) GetBanInfo(ip string) (bool, time.Time) {
value, exists := m.bannedIPs.Load(ip)
if !exists {
return false, time.Time{}
}
banEndTime := value.(time.Time)
now := time.Now()
if now.After(banEndTime) {
m.bannedIPs.Delete(ip)
return false, time.Time{}
}
return true, banEndTime
}
// UnbanIP 手动解封IP
func (m *IPBanManager) UnbanIP(ip string) bool {
_, exists := m.bannedIPs.Load(ip)
if exists {
m.bannedIPs.Delete(ip)
log.Printf("[Security] 手动解封IP: %s", ip)
return true
}
return false
}
// GetBannedIPs 获取所有被封禁的IP列表
func (m *IPBanManager) GetBannedIPs() map[string]time.Time {
result := make(map[string]time.Time)
now := time.Now()
m.bannedIPs.Range(func(key, value interface{}) bool {
ip := key.(string)
banEndTime := value.(time.Time)
// 清理过期的封禁
if now.After(banEndTime) {
m.bannedIPs.Delete(ip)
} else {
result[ip] = banEndTime
}
return true
})
return result
}
// GetStats 获取统计信息
func (m *IPBanManager) GetStats() map[string]interface{} {
bannedCount := 0
errorRecordCount := 0
m.bannedIPs.Range(func(key, value interface{}) bool {
bannedCount++
return true
})
m.errorCounts.Range(func(key, value interface{}) bool {
errorRecordCount++
return true
})
return map[string]interface{}{
"banned_ips_count": bannedCount,
"error_records_count": errorRecordCount,
"config": m.config,
}
}
// startCleanupTask 启动清理任务
func (m *IPBanManager) startCleanupTask() {
m.cleanupWG.Add(1)
go func() {
defer m.cleanupWG.Done()
ticker := time.NewTicker(time.Duration(m.config.CleanupIntervalMinutes) * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.cleanup()
case <-m.stopCleanup:
return
}
}
}()
}
// cleanup 清理过期数据
func (m *IPBanManager) cleanup() {
now := time.Now()
windowStart := now.Add(-time.Duration(m.config.WindowMinutes) * time.Minute)
// 清理过期的错误记录
var expiredIPs []string
m.errorCounts.Range(func(key, value interface{}) bool {
ip := key.(string)
record := value.(*errorRecord)
// 如果最后一次错误时间超出窗口,删除记录
if record.lastTime.Before(windowStart) {
expiredIPs = append(expiredIPs, ip)
}
return true
})
for _, ip := range expiredIPs {
m.errorCounts.Delete(ip)
}
// 清理过期的封禁记录
var expiredBans []string
m.bannedIPs.Range(func(key, value interface{}) bool {
ip := key.(string)
banEndTime := value.(time.Time)
if now.After(banEndTime) {
expiredBans = append(expiredBans, ip)
}
return true
})
for _, ip := range expiredBans {
m.bannedIPs.Delete(ip)
}
if len(expiredIPs) > 0 || len(expiredBans) > 0 {
log.Printf("[Security] 清理任务完成 - 清理错误记录: %d, 清理过期封禁: %d",
len(expiredIPs), len(expiredBans))
}
}
// Stop 停止IP封禁管理器
func (m *IPBanManager) Stop() {
close(m.stopCleanup)
m.cleanupWG.Wait()
log.Printf("[Security] IP封禁管理器已停止")
}

View File

@ -6,7 +6,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
neturl "net/url" neturl "net/url"
"path/filepath" "path/filepath"
@ -93,19 +92,6 @@ func GenerateRequestID() string {
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
func GetClientIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return ip
}
return r.RemoteAddr
}
// 获取请求来源 // 获取请求来源
func GetRequestSource(r *http.Request) string { func GetRequestSource(r *http.Request) string {
referer := r.Header.Get("Referer") referer := r.Header.Get("Referer")

48
main.go
View File

@ -13,6 +13,7 @@ import (
"proxy-go/internal/initapp" "proxy-go/internal/initapp"
"proxy-go/internal/metrics" "proxy-go/internal/metrics"
"proxy-go/internal/middleware" "proxy-go/internal/middleware"
"proxy-go/internal/security"
"strings" "strings"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
@ -55,6 +56,20 @@ func main() {
}) })
compManagerAtomic.Store(compManager) compManagerAtomic.Store(compManager)
// 创建安全管理器
var banManager *security.IPBanManager
var securityMiddleware *middleware.SecurityMiddleware
if cfg.Security.IPBan.Enabled {
banConfig := &security.IPBanConfig{
ErrorThreshold: cfg.Security.IPBan.ErrorThreshold,
WindowMinutes: cfg.Security.IPBan.WindowMinutes,
BanDurationMinutes: cfg.Security.IPBan.BanDurationMinutes,
CleanupIntervalMinutes: cfg.Security.IPBan.CleanupIntervalMinutes,
}
banManager = security.NewIPBanManager(banConfig)
securityMiddleware = middleware.NewSecurityMiddleware(banManager)
}
// 创建代理处理器 // 创建代理处理器
mirrorHandler := handler.NewMirrorProxyHandler() mirrorHandler := handler.NewMirrorProxyHandler()
proxyHandler := handler.NewProxyHandler(cfg) proxyHandler := handler.NewProxyHandler(cfg)
@ -62,6 +77,12 @@ func main() {
// 创建配置处理器 // 创建配置处理器
configHandler := handler.NewConfigHandler(configManager) configHandler := handler.NewConfigHandler(configManager)
// 创建安全管理处理器
var securityHandler *handler.SecurityHandler
if banManager != nil {
securityHandler = handler.NewSecurityHandler(banManager)
}
// 注册压缩配置更新回调 // 注册压缩配置更新回调
config.RegisterUpdateCallback(func(newCfg *config.Config) { config.RegisterUpdateCallback(func(newCfg *config.Config) {
// 更新压缩管理器 // 更新压缩管理器
@ -92,6 +113,17 @@ func main() {
{http.MethodPost, "/admin/api/cache/config", handler.NewCacheAdminHandler(proxyHandler.Cache, mirrorHandler.Cache).UpdateCacheConfig, true}, {http.MethodPost, "/admin/api/cache/config", handler.NewCacheAdminHandler(proxyHandler.Cache, mirrorHandler.Cache).UpdateCacheConfig, true},
} }
// 添加安全API路由如果启用了安全功能
if securityHandler != nil {
securityRoutes := []Route{
{http.MethodGet, "/admin/api/security/banned-ips", securityHandler.GetBannedIPs, true},
{http.MethodPost, "/admin/api/security/unban", securityHandler.UnbanIP, true},
{http.MethodGet, "/admin/api/security/stats", securityHandler.GetSecurityStats, true},
{http.MethodGet, "/admin/api/security/check-ip", securityHandler.CheckIPStatus, true},
}
apiRoutes = append(apiRoutes, securityRoutes...)
}
// 创建路由处理器 // 创建路由处理器
handlers := []struct { handlers := []struct {
matcher func(*http.Request) bool matcher func(*http.Request) bool
@ -165,13 +197,20 @@ func main() {
http.NotFound(w, r) http.NotFound(w, r)
}) })
// 添加压缩中间件(使用动态压缩管理器) // 构建中间件链
var handler http.Handler = mainHandler var handler http.Handler = mainHandler
// 添加安全中间件(最外层,优先级最高)
if securityMiddleware != nil {
handler = securityMiddleware.IPBanMiddleware(handler)
}
// 添加压缩中间件
if cfg.Compression.Gzip.Enabled || cfg.Compression.Brotli.Enabled { if cfg.Compression.Gzip.Enabled || cfg.Compression.Brotli.Enabled {
// 创建动态压缩中间件包装器 // 创建动态压缩中间件包装器
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
currentCompManager := compManagerAtomic.Load().(compression.Manager) currentCompManager := compManagerAtomic.Load().(compression.Manager)
middleware.CompressionMiddleware(currentCompManager)(mainHandler).ServeHTTP(w, r) middleware.CompressionMiddleware(currentCompManager)(handler).ServeHTTP(w, r)
}) })
} }
@ -188,6 +227,11 @@ func main() {
<-sigChan <-sigChan
log.Println("Shutting down server...") log.Println("Shutting down server...")
// 停止安全管理器
if banManager != nil {
banManager.Stop()
}
// 停止指标存储服务 // 停止指标存储服务
metrics.StopMetricsStorage() metrics.StopMetricsStorage()

View File

@ -17,7 +17,8 @@ import {
} from "@/components/ui/dialog" } from "@/components/ui/dialog"
import { Switch } from "@/components/ui/switch" import { Switch } from "@/components/ui/switch"
import { Slider } from "@/components/ui/slider" import { Slider } from "@/components/ui/slider"
import { Plus, Trash2, Edit, Download, Upload } from "lucide-react" import { Plus, Trash2, Edit, Download, Upload, Shield } from "lucide-react"
import Link from "next/link"
import { import {
AlertDialog, AlertDialog,
AlertDialogAction, AlertDialogAction,
@ -51,12 +52,23 @@ interface CompressionConfig {
Level: number Level: number
} }
interface SecurityConfig {
IPBan: {
Enabled: boolean
ErrorThreshold: number
WindowMinutes: number
BanDurationMinutes: number
CleanupIntervalMinutes: number
}
}
interface Config { interface Config {
MAP: Record<string, PathMapping | string> MAP: Record<string, PathMapping | string>
Compression: { Compression: {
Gzip: CompressionConfig Gzip: CompressionConfig
Brotli: CompressionConfig Brotli: CompressionConfig
} }
Security: SecurityConfig
} }
export default function ConfigPage() { export default function ConfigPage() {
@ -163,6 +175,20 @@ export default function ConfigPage() {
} }
const data = await response.json() const data = await response.json()
// 确保安全配置存在
if (!data.Security) {
data.Security = {
IPBan: {
Enabled: false,
ErrorThreshold: 10,
WindowMinutes: 5,
BanDurationMinutes: 5,
CleanupIntervalMinutes: 1
}
}
}
isConfigFromApiRef.current = true // 标记配置来自API isConfigFromApiRef.current = true // 标记配置来自API
setConfig(data) setConfig(data)
} catch (error) { } catch (error) {
@ -374,6 +400,31 @@ export default function ConfigPage() {
updateConfig(newConfig) updateConfig(newConfig)
} }
const updateSecurity = (field: keyof SecurityConfig['IPBan'], value: boolean | number) => {
if (!config) return
const newConfig = { ...config }
// 确保安全配置存在
if (!newConfig.Security) {
newConfig.Security = {
IPBan: {
Enabled: false,
ErrorThreshold: 10,
WindowMinutes: 5,
BanDurationMinutes: 5,
CleanupIntervalMinutes: 1
}
}
}
if (field === 'Enabled') {
newConfig.Security.IPBan.Enabled = value as boolean
} else {
newConfig.Security.IPBan[field] = value as number
}
updateConfig(newConfig)
}
const handleExtensionMapEdit = (path: string) => { const handleExtensionMapEdit = (path: string) => {
// 将添加规则的操作重定向到handleExtensionRuleEdit // 将添加规则的操作重定向到handleExtensionRuleEdit
handleExtensionRuleEdit(path); handleExtensionRuleEdit(path);
@ -433,6 +484,19 @@ export default function ConfigPage() {
throw new Error('配置文件压缩设置格式不正确') throw new Error('配置文件压缩设置格式不正确')
} }
// 如果没有安全配置,添加默认配置
if (!newConfig.Security) {
newConfig.Security = {
IPBan: {
Enabled: false,
ErrorThreshold: 10,
WindowMinutes: 5,
BanDurationMinutes: 5,
CleanupIntervalMinutes: 1
}
}
}
// 验证路径映射 // 验证路径映射
for (const [path, target] of Object.entries(newConfig.MAP)) { for (const [path, target] of Object.entries(newConfig.MAP)) {
if (!path.startsWith('/')) { if (!path.startsWith('/')) {
@ -785,6 +849,7 @@ export default function ConfigPage() {
<TabsList> <TabsList>
<TabsTrigger value="paths"></TabsTrigger> <TabsTrigger value="paths"></TabsTrigger>
<TabsTrigger value="compression"></TabsTrigger> <TabsTrigger value="compression"></TabsTrigger>
<TabsTrigger value="security"></TabsTrigger>
</TabsList> </TabsList>
<TabsContent value="paths" className="space-y-4"> <TabsContent value="paths" className="space-y-4">
@ -1015,6 +1080,94 @@ export default function ConfigPage() {
</CardContent> </CardContent>
</Card> </Card>
</TabsContent> </TabsContent>
<TabsContent value="security" className="space-y-6">
<Card>
<CardHeader className="flex flex-row items-center justify-between">
<CardTitle>IP </CardTitle>
<Button variant="outline" asChild>
<Link href="/dashboard/security">
<Shield className="w-4 h-4 mr-2" />
</Link>
</Button>
</CardHeader>
<CardContent className="space-y-4">
<div className="flex items-center justify-between">
<div>
<Label> IP </Label>
<p className="text-sm text-muted-foreground">
IP 访
</p>
</div>
<Switch
checked={config?.Security?.IPBan?.Enabled || false}
onCheckedChange={(checked) => updateSecurity('Enabled', checked)}
/>
</div>
{config?.Security?.IPBan?.Enabled && (
<>
<div className="space-y-2">
<Label>404 </Label>
<Input
type="number"
min={1}
max={100}
value={config?.Security?.IPBan?.ErrorThreshold || 10}
onChange={(e) => updateSecurity('ErrorThreshold', parseInt(e.target.value) || 10)}
/>
<p className="text-sm text-muted-foreground">
IP 访
</p>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="number"
min={1}
max={60}
value={config?.Security?.IPBan?.WindowMinutes || 5}
onChange={(e) => updateSecurity('WindowMinutes', parseInt(e.target.value) || 5)}
/>
<p className="text-sm text-muted-foreground">
404
</p>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="number"
min={1}
max={1440}
value={config?.Security?.IPBan?.BanDurationMinutes || 5}
onChange={(e) => updateSecurity('BanDurationMinutes', parseInt(e.target.value) || 5)}
/>
<p className="text-sm text-muted-foreground">
IP
</p>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="number"
min={1}
max={60}
value={config?.Security?.IPBan?.CleanupIntervalMinutes || 1}
onChange={(e) => updateSecurity('CleanupIntervalMinutes', parseInt(e.target.value) || 1)}
/>
<p className="text-sm text-muted-foreground">
</p>
</div>
</>
)}
</CardContent>
</Card>
</TabsContent>
</Tabs> </Tabs>
</CardContent> </CardContent>
</Card> </Card>

View File

@ -0,0 +1,386 @@
"use client"
import React, { useEffect, useState, useCallback } from "react"
import { Button } from "@/components/ui/button"
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"
import { useToast } from "@/components/ui/use-toast"
import { useRouter } from "next/navigation"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table"
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog"
import { Shield, Ban, Clock, Trash2, RefreshCw } from "lucide-react"
interface BannedIP {
ip: string
ban_end_time: string
remaining_seconds: number
}
interface SecurityStats {
banned_ips_count: number
error_records_count: number
config: {
ErrorThreshold: number
WindowMinutes: number
BanDurationMinutes: number
CleanupIntervalMinutes: number
}
}
interface IPStatus {
ip: string
banned: boolean
ban_end_time?: string
remaining_seconds?: number
}
export default function SecurityPage() {
const [bannedIPs, setBannedIPs] = useState<BannedIP[]>([])
const [stats, setStats] = useState<SecurityStats | null>(null)
const [loading, setLoading] = useState(true)
const [refreshing, setRefreshing] = useState(false)
const [checkingIP, setCheckingIP] = useState("")
const [ipStatus, setIPStatus] = useState<IPStatus | null>(null)
const [unbanning, setUnbanning] = useState<string | null>(null)
const { toast } = useToast()
const router = useRouter()
const fetchData = useCallback(async () => {
try {
const token = localStorage.getItem("token")
if (!token) {
router.push("/login")
return
}
const [bannedResponse, statsResponse] = await Promise.all([
fetch("/admin/api/security/banned-ips", {
headers: { 'Authorization': `Bearer ${token}` }
}),
fetch("/admin/api/security/stats", {
headers: { 'Authorization': `Bearer ${token}` }
})
])
if (bannedResponse.status === 401 || statsResponse.status === 401) {
localStorage.removeItem("token")
router.push("/login")
return
}
if (bannedResponse.ok) {
const bannedData = await bannedResponse.json()
setBannedIPs(bannedData.banned_ips || [])
}
if (statsResponse.ok) {
const statsData = await statsResponse.json()
setStats(statsData)
}
} catch (error) {
console.error("获取安全数据失败:", error)
toast({
title: "错误",
description: "获取安全数据失败",
variant: "destructive",
})
} finally {
setLoading(false)
setRefreshing(false)
}
}, [router, toast])
useEffect(() => {
fetchData()
// 每30秒自动刷新一次数据
const interval = setInterval(fetchData, 30000)
return () => clearInterval(interval)
}, [fetchData])
const handleRefresh = () => {
setRefreshing(true)
fetchData()
}
const checkIPStatus = async () => {
if (!checkingIP.trim()) return
try {
const token = localStorage.getItem("token")
if (!token) {
router.push("/login")
return
}
const response = await fetch(`/admin/api/security/check-ip?ip=${encodeURIComponent(checkingIP)}`, {
headers: { 'Authorization': `Bearer ${token}` }
})
if (response.status === 401) {
localStorage.removeItem("token")
router.push("/login")
return
}
if (response.ok) {
const data = await response.json()
setIPStatus(data)
} else {
throw new Error("检查IP状态失败")
}
} catch {
toast({
title: "错误",
description: "检查IP状态失败",
variant: "destructive",
})
}
}
const unbanIP = async (ip: string) => {
try {
const token = localStorage.getItem("token")
if (!token) {
router.push("/login")
return
}
const response = await fetch("/admin/api/security/unban", {
method: "POST",
headers: {
'Authorization': `Bearer ${token}`,
'Content-Type': 'application/json'
},
body: JSON.stringify({ ip })
})
if (response.status === 401) {
localStorage.removeItem("token")
router.push("/login")
return
}
if (response.ok) {
const data = await response.json()
if (data.success) {
toast({
title: "成功",
description: `IP ${ip} 已解封`,
})
fetchData() // 刷新数据
} else {
toast({
title: "提示",
description: data.message,
})
}
} else {
throw new Error("解封IP失败")
}
} catch {
toast({
title: "错误",
description: "解封IP失败",
variant: "destructive",
})
} finally {
setUnbanning(null)
}
}
const formatTime = (seconds: number) => {
if (seconds <= 0) return "已过期"
const minutes = Math.floor(seconds / 60)
const remainingSeconds = seconds % 60
if (minutes > 0) {
return `${minutes}${remainingSeconds}`
}
return `${remainingSeconds}`
}
if (loading) {
return (
<div className="flex h-[calc(100vh-4rem)] items-center justify-center">
<div className="text-center">
<div className="text-lg font-medium">...</div>
<div className="text-sm text-gray-500 mt-1"></div>
</div>
</div>
)
}
return (
<div className="space-y-6">
<Card>
<CardHeader className="flex flex-row items-center justify-between">
<CardTitle className="flex items-center gap-2">
<Shield className="w-5 h-5" />
</CardTitle>
<Button onClick={handleRefresh} disabled={refreshing} variant="outline">
<RefreshCw className={`w-4 h-4 mr-2 ${refreshing ? 'animate-spin' : ''}`} />
</Button>
</CardHeader>
<CardContent>
{stats && (
<div className="grid grid-cols-1 md:grid-cols-4 gap-4 mb-6">
<div className="bg-red-50 p-4 rounded-lg">
<div className="flex items-center gap-2">
<Ban className="w-5 h-5 text-red-600" />
<div>
<div className="text-2xl font-bold text-red-600">{stats.banned_ips_count}</div>
<div className="text-sm text-red-600">IP</div>
</div>
</div>
</div>
<div className="bg-yellow-50 p-4 rounded-lg">
<div className="flex items-center gap-2">
<Clock className="w-5 h-5 text-yellow-600" />
<div>
<div className="text-2xl font-bold text-yellow-600">{stats.error_records_count}</div>
<div className="text-sm text-yellow-600"></div>
</div>
</div>
</div>
<div className="bg-blue-50 p-4 rounded-lg">
<div className="text-sm text-blue-600 mb-1"></div>
<div className="text-lg font-bold text-blue-600">
{stats.config.ErrorThreshold}/{stats.config.WindowMinutes}
</div>
</div>
<div className="bg-green-50 p-4 rounded-lg">
<div className="text-sm text-green-600 mb-1"></div>
<div className="text-lg font-bold text-green-600">
{stats.config.BanDurationMinutes}
</div>
</div>
</div>
)}
<div className="space-y-4">
<div className="flex gap-4">
<div className="flex-1">
<Label>IP状态</Label>
<div className="flex gap-2 mt-1">
<Input
placeholder="输入IP地址"
value={checkingIP}
onChange={(e) => setCheckingIP(e.target.value)}
/>
<Button onClick={checkIPStatus}></Button>
</div>
</div>
</div>
{ipStatus && (
<Card>
<CardContent className="pt-4">
<div className="flex items-center gap-4">
<div>
<strong>IP: {ipStatus.ip}</strong>
</div>
<div className={`px-2 py-1 rounded text-sm ${
ipStatus.banned
? 'bg-red-100 text-red-800'
: 'bg-green-100 text-green-800'
}`}>
{ipStatus.banned ? '已封禁' : '正常'}
</div>
{ipStatus.banned && ipStatus.remaining_seconds && ipStatus.remaining_seconds > 0 && (
<div className="text-sm text-muted-foreground">
: {formatTime(ipStatus.remaining_seconds)}
</div>
)}
</div>
</CardContent>
</Card>
)}
</div>
</CardContent>
</Card>
<Card>
<CardHeader>
<CardTitle>IP列表</CardTitle>
</CardHeader>
<CardContent>
{bannedIPs.length === 0 ? (
<div className="text-center py-8 text-muted-foreground">
IP
</div>
) : (
<Table>
<TableHeader>
<TableRow>
<TableHead>IP地址</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{bannedIPs.map((bannedIP) => (
<TableRow key={bannedIP.ip}>
<TableCell className="font-mono">{bannedIP.ip}</TableCell>
<TableCell>{bannedIP.ban_end_time}</TableCell>
<TableCell>
<span className={bannedIP.remaining_seconds <= 0 ? 'text-muted-foreground' : 'text-orange-600'}>
{formatTime(bannedIP.remaining_seconds)}
</span>
</TableCell>
<TableCell>
<Button
variant="outline"
size="sm"
onClick={() => setUnbanning(bannedIP.ip)}
disabled={bannedIP.remaining_seconds <= 0}
>
<Trash2 className="w-4 h-4 mr-1" />
</Button>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)}
</CardContent>
</Card>
<AlertDialog open={!!unbanning} onOpenChange={(open) => !open && setUnbanning(null)}>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle></AlertDialogTitle>
<AlertDialogDescription>
IP地址 &ldquo;{unbanning}&rdquo;
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel></AlertDialogCancel>
<AlertDialogAction onClick={() => unbanning && unbanIP(unbanning)}>
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</div>
)
}

View File

@ -55,6 +55,12 @@ export function Nav() {
> >
</Link> </Link>
<Link
href="/dashboard/security"
className={pathname === "/dashboard/security" ? "text-primary" : "text-muted-foreground"}
>
</Link>
</div> </div>
<Button variant="ghost" onClick={handleLogout}> <Button variant="ghost" onClick={handleLogout}>
退 退