feat(auth): 增强OAuth认证状态管理和安全性

- 新增 state 状态管理机制,增加 10 分钟有效期
- 实现 generateState 和 validateState 方法
- 优化 LoginHandler 和 OAuthCallbackHandler 中的状态验证逻辑
- 添加更详细的调试和错误日志记录
- 完善回调地址生成逻辑,支持更多网络环境
- 在 OAuth 授权请求中添加 scope 参数
This commit is contained in:
wood chen 2025-03-12 20:27:20 +08:00
parent 2626f63770
commit 7f4a964163

View File

@ -18,6 +18,7 @@ import (
const (
tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天
stateExpiry = 10 * time.Minute // State 过期时间为 10 分钟
)
type OAuthUserInfo struct {
@ -41,9 +42,10 @@ type OAuthUserInfo struct {
}
type OAuthToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
}
type tokenInfo struct {
@ -52,6 +54,11 @@ type tokenInfo struct {
username string
}
type stateInfo struct {
createdAt time.Time
expiresAt time.Time
}
type authManager struct {
tokens sync.Map
states sync.Map
@ -60,6 +67,7 @@ type authManager struct {
func newAuthManager() *authManager {
am := &authManager{}
go am.cleanExpiredTokens()
go am.cleanExpiredStates()
return am
}
@ -69,6 +77,27 @@ func (am *authManager) generateToken() string {
return base64.URLEncoding.EncodeToString(b)
}
func (am *authManager) generateState() string {
state := am.generateToken()
am.states.Store(state, stateInfo{
createdAt: time.Now(),
expiresAt: time.Now().Add(stateExpiry),
})
return state
}
func (am *authManager) validateState(state string) bool {
if info, ok := am.states.Load(state); ok {
stateInfo := info.(stateInfo)
if time.Now().Before(stateInfo.expiresAt) {
am.states.Delete(state) // 使用后立即删除
return true
}
am.states.Delete(state) // 过期也删除
}
return false
}
func (am *authManager) addToken(token string, username string, expiry time.Duration) {
am.tokens.Store(token, tokenInfo{
createdAt: time.Now(),
@ -102,6 +131,20 @@ func (am *authManager) cleanExpiredTokens() {
}
}
func (am *authManager) cleanExpiredStates() {
ticker := time.NewTicker(time.Minute)
for range ticker.C {
am.states.Range(func(key, value interface{}) bool {
state := key.(string)
info := value.(stateInfo)
if time.Now().After(info.expiresAt) {
am.states.Delete(state)
}
return true
})
}
}
// CheckAuth 检查认证令牌是否有效
func (h *ProxyHandler) CheckAuth(token string) bool {
return h.auth.validateToken(token)
@ -150,30 +193,48 @@ func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
// getCallbackURL 从请求中获取回调地址
func getCallbackURL(r *http.Request) string {
if os.Getenv("OAUTH_REDIRECT_URI") != "" {
return os.Getenv("OAUTH_REDIRECT_URI")
} else {
scheme := "http"
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
if redirectURI := os.Getenv("OAUTH_REDIRECT_URI"); redirectURI != "" {
// 验证URI格式
if _, err := url.Parse(redirectURI); err == nil {
log.Printf("[Auth] DEBUG Using configured OAUTH_REDIRECT_URI: %s", redirectURI)
return redirectURI
}
return fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, r.Host)
log.Printf("[Auth] WARNING Invalid OAUTH_REDIRECT_URI format: %s", redirectURI)
}
// 更可靠地检测协议
scheme := "http"
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
}
// 考虑X-Forwarded-Host头
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
callbackURL := fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, host)
log.Printf("[Auth] DEBUG Generated callback URL: %s", callbackURL)
return callbackURL
}
// LoginHandler 处理登录请求,重定向到 OAuth 授权页面
func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
state := h.auth.generateToken()
h.auth.states.Store(state, time.Now())
state := h.auth.generateState()
clientID := os.Getenv("OAUTH_CLIENT_ID")
redirectURI := getCallbackURL(r)
// 记录生成的state和重定向URI
log.Printf("[Auth] DEBUG %s %s -> Generated state=%s, redirect_uri=%s",
r.Method, r.URL.Path, state, redirectURI)
authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s",
url.Values{
"response_type": {"code"},
"client_id": {clientID},
"redirect_uri": {redirectURI},
"scope": {"read write"}, // 添加scope参数
"state": {state},
}.Encode())
@ -185,13 +246,17 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
// 记录完整请求信息
log.Printf("[Auth] DEBUG %s %s -> Callback received with state=%s, code=%s, full URL: %s",
r.Method, r.URL.Path, state, code, r.URL.String())
// 验证 state
if _, ok := h.auth.states.Load(state); !ok {
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r))
if !h.auth.validateState(state) {
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s",
r.Method, r.URL.Path, utils.GetClientIP(r), state, utils.GetRequestSource(r))
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
h.auth.states.Delete(state)
// 验证code参数
if code == "" {
@ -212,6 +277,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
return
}
// 记录令牌交换请求信息
log.Printf("[Auth] DEBUG %s %s -> Exchanging code for token with redirect_uri=%s",
r.Method, r.URL.Path, redirectURI)
resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token",
url.Values{
"grant_type": {"authorization_code"},
@ -229,8 +298,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
// 检查响应状态码
if resp.StatusCode != http.StatusOK {
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error status: %s from %s",
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, utils.GetRequestSource(r))
// 读取错误响应内容
bodyBytes, _ := io.ReadAll(resp.Body)
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s",
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, string(bodyBytes))
http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError)
return
}