From 7f4a964163796e39cc473adb2cbef40024318eca Mon Sep 17 00:00:00 2001 From: wood chen Date: Wed, 12 Mar 2025 20:27:20 +0800 Subject: [PATCH] =?UTF-8?q?feat(auth):=20=E5=A2=9E=E5=BC=BAOAuth=E8=AE=A4?= =?UTF-8?q?=E8=AF=81=E7=8A=B6=E6=80=81=E7=AE=A1=E7=90=86=E5=92=8C=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 state 状态管理机制,增加 10 分钟有效期 - 实现 generateState 和 validateState 方法 - 优化 LoginHandler 和 OAuthCallbackHandler 中的状态验证逻辑 - 添加更详细的调试和错误日志记录 - 完善回调地址生成逻辑,支持更多网络环境 - 在 OAuth 授权请求中添加 scope 参数 --- internal/handler/auth.go | 107 ++++++++++++++++++++++++++++++++------- 1 file changed, 89 insertions(+), 18 deletions(-) diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 486cc4d..1b3f38f 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -18,6 +18,7 @@ import ( const ( tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天 + stateExpiry = 10 * time.Minute // State 过期时间为 10 分钟 ) type OAuthUserInfo struct { @@ -41,9 +42,10 @@ type OAuthUserInfo struct { } type OAuthToken struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` } type tokenInfo struct { @@ -52,6 +54,11 @@ type tokenInfo struct { username string } +type stateInfo struct { + createdAt time.Time + expiresAt time.Time +} + type authManager struct { tokens sync.Map states sync.Map @@ -60,6 +67,7 @@ type authManager struct { func newAuthManager() *authManager { am := &authManager{} go am.cleanExpiredTokens() + go am.cleanExpiredStates() return am } @@ -69,6 +77,27 @@ func (am *authManager) generateToken() string { return base64.URLEncoding.EncodeToString(b) } +func (am *authManager) generateState() string { + state := am.generateToken() + am.states.Store(state, stateInfo{ + createdAt: time.Now(), + expiresAt: time.Now().Add(stateExpiry), + }) + return state +} + +func (am *authManager) validateState(state string) bool { + if info, ok := am.states.Load(state); ok { + stateInfo := info.(stateInfo) + if time.Now().Before(stateInfo.expiresAt) { + am.states.Delete(state) // 使用后立即删除 + return true + } + am.states.Delete(state) // 过期也删除 + } + return false +} + func (am *authManager) addToken(token string, username string, expiry time.Duration) { am.tokens.Store(token, tokenInfo{ createdAt: time.Now(), @@ -102,6 +131,20 @@ func (am *authManager) cleanExpiredTokens() { } } +func (am *authManager) cleanExpiredStates() { + ticker := time.NewTicker(time.Minute) + for range ticker.C { + am.states.Range(func(key, value interface{}) bool { + state := key.(string) + info := value.(stateInfo) + if time.Now().After(info.expiresAt) { + am.states.Delete(state) + } + return true + }) + } +} + // CheckAuth 检查认证令牌是否有效 func (h *ProxyHandler) CheckAuth(token string) bool { return h.auth.validateToken(token) @@ -150,30 +193,48 @@ func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc { // getCallbackURL 从请求中获取回调地址 func getCallbackURL(r *http.Request) string { - if os.Getenv("OAUTH_REDIRECT_URI") != "" { - return os.Getenv("OAUTH_REDIRECT_URI") - } else { - scheme := "http" - if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { - scheme = "https" + if redirectURI := os.Getenv("OAUTH_REDIRECT_URI"); redirectURI != "" { + // 验证URI格式 + if _, err := url.Parse(redirectURI); err == nil { + log.Printf("[Auth] DEBUG Using configured OAUTH_REDIRECT_URI: %s", redirectURI) + return redirectURI } - return fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, r.Host) + log.Printf("[Auth] WARNING Invalid OAUTH_REDIRECT_URI format: %s", redirectURI) } + + // 更可靠地检测协议 + scheme := "http" + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + scheme = "https" + } + + // 考虑X-Forwarded-Host头 + host := r.Host + if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + + callbackURL := fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, host) + log.Printf("[Auth] DEBUG Generated callback URL: %s", callbackURL) + return callbackURL } // LoginHandler 处理登录请求,重定向到 OAuth 授权页面 func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { - state := h.auth.generateToken() - h.auth.states.Store(state, time.Now()) - + state := h.auth.generateState() clientID := os.Getenv("OAUTH_CLIENT_ID") redirectURI := getCallbackURL(r) + // 记录生成的state和重定向URI + log.Printf("[Auth] DEBUG %s %s -> Generated state=%s, redirect_uri=%s", + r.Method, r.URL.Path, state, redirectURI) + authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s", url.Values{ "response_type": {"code"}, "client_id": {clientID}, "redirect_uri": {redirectURI}, + "scope": {"read write"}, // 添加scope参数 "state": {state}, }.Encode()) @@ -185,13 +246,17 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") + // 记录完整请求信息 + log.Printf("[Auth] DEBUG %s %s -> Callback received with state=%s, code=%s, full URL: %s", + r.Method, r.URL.Path, state, code, r.URL.String()) + // 验证 state - if _, ok := h.auth.states.Load(state); !ok { - log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) + if !h.auth.validateState(state) { + log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s", + r.Method, r.URL.Path, utils.GetClientIP(r), state, utils.GetRequestSource(r)) http.Error(w, "Invalid state", http.StatusBadRequest) return } - h.auth.states.Delete(state) // 验证code参数 if code == "" { @@ -212,6 +277,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque return } + // 记录令牌交换请求信息 + log.Printf("[Auth] DEBUG %s %s -> Exchanging code for token with redirect_uri=%s", + r.Method, r.URL.Path, redirectURI) + resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token", url.Values{ "grant_type": {"authorization_code"}, @@ -229,8 +298,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque // 检查响应状态码 if resp.StatusCode != http.StatusOK { - log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error status: %s from %s", - r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, utils.GetRequestSource(r)) + // 读取错误响应内容 + bodyBytes, _ := io.ReadAll(resp.Body) + log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s", + r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, string(bodyBytes)) http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError) return }