feat(auth): 增强OAuth认证状态管理和安全性

- 新增 state 状态管理机制,增加 10 分钟有效期
- 实现 generateState 和 validateState 方法
- 优化 LoginHandler 和 OAuthCallbackHandler 中的状态验证逻辑
- 添加更详细的调试和错误日志记录
- 完善回调地址生成逻辑,支持更多网络环境
- 在 OAuth 授权请求中添加 scope 参数
This commit is contained in:
wood chen 2025-03-12 20:27:20 +08:00
parent 2626f63770
commit 7f4a964163

View File

@ -18,6 +18,7 @@ import (
const ( const (
tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天 tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天
stateExpiry = 10 * time.Minute // State 过期时间为 10 分钟
) )
type OAuthUserInfo struct { type OAuthUserInfo struct {
@ -44,6 +45,7 @@ type OAuthToken struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
} }
type tokenInfo struct { type tokenInfo struct {
@ -52,6 +54,11 @@ type tokenInfo struct {
username string username string
} }
type stateInfo struct {
createdAt time.Time
expiresAt time.Time
}
type authManager struct { type authManager struct {
tokens sync.Map tokens sync.Map
states sync.Map states sync.Map
@ -60,6 +67,7 @@ type authManager struct {
func newAuthManager() *authManager { func newAuthManager() *authManager {
am := &authManager{} am := &authManager{}
go am.cleanExpiredTokens() go am.cleanExpiredTokens()
go am.cleanExpiredStates()
return am return am
} }
@ -69,6 +77,27 @@ func (am *authManager) generateToken() string {
return base64.URLEncoding.EncodeToString(b) return base64.URLEncoding.EncodeToString(b)
} }
func (am *authManager) generateState() string {
state := am.generateToken()
am.states.Store(state, stateInfo{
createdAt: time.Now(),
expiresAt: time.Now().Add(stateExpiry),
})
return state
}
func (am *authManager) validateState(state string) bool {
if info, ok := am.states.Load(state); ok {
stateInfo := info.(stateInfo)
if time.Now().Before(stateInfo.expiresAt) {
am.states.Delete(state) // 使用后立即删除
return true
}
am.states.Delete(state) // 过期也删除
}
return false
}
func (am *authManager) addToken(token string, username string, expiry time.Duration) { func (am *authManager) addToken(token string, username string, expiry time.Duration) {
am.tokens.Store(token, tokenInfo{ am.tokens.Store(token, tokenInfo{
createdAt: time.Now(), createdAt: time.Now(),
@ -102,6 +131,20 @@ func (am *authManager) cleanExpiredTokens() {
} }
} }
func (am *authManager) cleanExpiredStates() {
ticker := time.NewTicker(time.Minute)
for range ticker.C {
am.states.Range(func(key, value interface{}) bool {
state := key.(string)
info := value.(stateInfo)
if time.Now().After(info.expiresAt) {
am.states.Delete(state)
}
return true
})
}
}
// CheckAuth 检查认证令牌是否有效 // CheckAuth 检查认证令牌是否有效
func (h *ProxyHandler) CheckAuth(token string) bool { func (h *ProxyHandler) CheckAuth(token string) bool {
return h.auth.validateToken(token) return h.auth.validateToken(token)
@ -150,30 +193,48 @@ func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
// getCallbackURL 从请求中获取回调地址 // getCallbackURL 从请求中获取回调地址
func getCallbackURL(r *http.Request) string { func getCallbackURL(r *http.Request) string {
if os.Getenv("OAUTH_REDIRECT_URI") != "" { if redirectURI := os.Getenv("OAUTH_REDIRECT_URI"); redirectURI != "" {
return os.Getenv("OAUTH_REDIRECT_URI") // 验证URI格式
} else { if _, err := url.Parse(redirectURI); err == nil {
log.Printf("[Auth] DEBUG Using configured OAUTH_REDIRECT_URI: %s", redirectURI)
return redirectURI
}
log.Printf("[Auth] WARNING Invalid OAUTH_REDIRECT_URI format: %s", redirectURI)
}
// 更可靠地检测协议
scheme := "http" scheme := "http"
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https" scheme = "https"
} }
return fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, r.Host)
// 考虑X-Forwarded-Host头
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
} }
callbackURL := fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, host)
log.Printf("[Auth] DEBUG Generated callback URL: %s", callbackURL)
return callbackURL
} }
// LoginHandler 处理登录请求,重定向到 OAuth 授权页面 // LoginHandler 处理登录请求,重定向到 OAuth 授权页面
func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
state := h.auth.generateToken() state := h.auth.generateState()
h.auth.states.Store(state, time.Now())
clientID := os.Getenv("OAUTH_CLIENT_ID") clientID := os.Getenv("OAUTH_CLIENT_ID")
redirectURI := getCallbackURL(r) redirectURI := getCallbackURL(r)
// 记录生成的state和重定向URI
log.Printf("[Auth] DEBUG %s %s -> Generated state=%s, redirect_uri=%s",
r.Method, r.URL.Path, state, redirectURI)
authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s", authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s",
url.Values{ url.Values{
"response_type": {"code"}, "response_type": {"code"},
"client_id": {clientID}, "client_id": {clientID},
"redirect_uri": {redirectURI}, "redirect_uri": {redirectURI},
"scope": {"read write"}, // 添加scope参数
"state": {state}, "state": {state},
}.Encode()) }.Encode())
@ -185,13 +246,17 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state") state := r.URL.Query().Get("state")
// 记录完整请求信息
log.Printf("[Auth] DEBUG %s %s -> Callback received with state=%s, code=%s, full URL: %s",
r.Method, r.URL.Path, state, code, r.URL.String())
// 验证 state // 验证 state
if _, ok := h.auth.states.Load(state); !ok { if !h.auth.validateState(state) {
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r)) log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s",
r.Method, r.URL.Path, utils.GetClientIP(r), state, utils.GetRequestSource(r))
http.Error(w, "Invalid state", http.StatusBadRequest) http.Error(w, "Invalid state", http.StatusBadRequest)
return return
} }
h.auth.states.Delete(state)
// 验证code参数 // 验证code参数
if code == "" { if code == "" {
@ -212,6 +277,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
return return
} }
// 记录令牌交换请求信息
log.Printf("[Auth] DEBUG %s %s -> Exchanging code for token with redirect_uri=%s",
r.Method, r.URL.Path, redirectURI)
resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token", resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token",
url.Values{ url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
@ -229,8 +298,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 检查响应状态码 // 检查响应状态码
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error status: %s from %s", // 读取错误响应内容
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, utils.GetRequestSource(r)) bodyBytes, _ := io.ReadAll(resp.Body)
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s",
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, string(bodyBytes))
http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError) http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError)
return return
} }