mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 08:31:55 +08:00
feat(auth): 增强OAuth认证状态管理和安全性
- 新增 state 状态管理机制,增加 10 分钟有效期 - 实现 generateState 和 validateState 方法 - 优化 LoginHandler 和 OAuthCallbackHandler 中的状态验证逻辑 - 添加更详细的调试和错误日志记录 - 完善回调地址生成逻辑,支持更多网络环境 - 在 OAuth 授权请求中添加 scope 参数
This commit is contained in:
parent
2626f63770
commit
7f4a964163
@ -18,6 +18,7 @@ import (
|
||||
|
||||
const (
|
||||
tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天
|
||||
stateExpiry = 10 * time.Minute // State 过期时间为 10 分钟
|
||||
)
|
||||
|
||||
type OAuthUserInfo struct {
|
||||
@ -41,9 +42,10 @@ type OAuthUserInfo struct {
|
||||
}
|
||||
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
type tokenInfo struct {
|
||||
@ -52,6 +54,11 @@ type tokenInfo struct {
|
||||
username string
|
||||
}
|
||||
|
||||
type stateInfo struct {
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type authManager struct {
|
||||
tokens sync.Map
|
||||
states sync.Map
|
||||
@ -60,6 +67,7 @@ type authManager struct {
|
||||
func newAuthManager() *authManager {
|
||||
am := &authManager{}
|
||||
go am.cleanExpiredTokens()
|
||||
go am.cleanExpiredStates()
|
||||
return am
|
||||
}
|
||||
|
||||
@ -69,6 +77,27 @@ func (am *authManager) generateToken() string {
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func (am *authManager) generateState() string {
|
||||
state := am.generateToken()
|
||||
am.states.Store(state, stateInfo{
|
||||
createdAt: time.Now(),
|
||||
expiresAt: time.Now().Add(stateExpiry),
|
||||
})
|
||||
return state
|
||||
}
|
||||
|
||||
func (am *authManager) validateState(state string) bool {
|
||||
if info, ok := am.states.Load(state); ok {
|
||||
stateInfo := info.(stateInfo)
|
||||
if time.Now().Before(stateInfo.expiresAt) {
|
||||
am.states.Delete(state) // 使用后立即删除
|
||||
return true
|
||||
}
|
||||
am.states.Delete(state) // 过期也删除
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (am *authManager) addToken(token string, username string, expiry time.Duration) {
|
||||
am.tokens.Store(token, tokenInfo{
|
||||
createdAt: time.Now(),
|
||||
@ -102,6 +131,20 @@ func (am *authManager) cleanExpiredTokens() {
|
||||
}
|
||||
}
|
||||
|
||||
func (am *authManager) cleanExpiredStates() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
for range ticker.C {
|
||||
am.states.Range(func(key, value interface{}) bool {
|
||||
state := key.(string)
|
||||
info := value.(stateInfo)
|
||||
if time.Now().After(info.expiresAt) {
|
||||
am.states.Delete(state)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAuth 检查认证令牌是否有效
|
||||
func (h *ProxyHandler) CheckAuth(token string) bool {
|
||||
return h.auth.validateToken(token)
|
||||
@ -150,30 +193,48 @@ func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
||||
// getCallbackURL 从请求中获取回调地址
|
||||
func getCallbackURL(r *http.Request) string {
|
||||
if os.Getenv("OAUTH_REDIRECT_URI") != "" {
|
||||
return os.Getenv("OAUTH_REDIRECT_URI")
|
||||
} else {
|
||||
scheme := "http"
|
||||
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
if redirectURI := os.Getenv("OAUTH_REDIRECT_URI"); redirectURI != "" {
|
||||
// 验证URI格式
|
||||
if _, err := url.Parse(redirectURI); err == nil {
|
||||
log.Printf("[Auth] DEBUG Using configured OAUTH_REDIRECT_URI: %s", redirectURI)
|
||||
return redirectURI
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, r.Host)
|
||||
log.Printf("[Auth] WARNING Invalid OAUTH_REDIRECT_URI format: %s", redirectURI)
|
||||
}
|
||||
|
||||
// 更可靠地检测协议
|
||||
scheme := "http"
|
||||
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// 考虑X-Forwarded-Host头
|
||||
host := r.Host
|
||||
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
|
||||
callbackURL := fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, host)
|
||||
log.Printf("[Auth] DEBUG Generated callback URL: %s", callbackURL)
|
||||
return callbackURL
|
||||
}
|
||||
|
||||
// LoginHandler 处理登录请求,重定向到 OAuth 授权页面
|
||||
func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
state := h.auth.generateToken()
|
||||
h.auth.states.Store(state, time.Now())
|
||||
|
||||
state := h.auth.generateState()
|
||||
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
||||
redirectURI := getCallbackURL(r)
|
||||
|
||||
// 记录生成的state和重定向URI
|
||||
log.Printf("[Auth] DEBUG %s %s -> Generated state=%s, redirect_uri=%s",
|
||||
r.Method, r.URL.Path, state, redirectURI)
|
||||
|
||||
authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s",
|
||||
url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {clientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {"read write"}, // 添加scope参数
|
||||
"state": {state},
|
||||
}.Encode())
|
||||
|
||||
@ -185,13 +246,17 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
// 记录完整请求信息
|
||||
log.Printf("[Auth] DEBUG %s %s -> Callback received with state=%s, code=%s, full URL: %s",
|
||||
r.Method, r.URL.Path, state, code, r.URL.String())
|
||||
|
||||
// 验证 state
|
||||
if _, ok := h.auth.states.Load(state); !ok {
|
||||
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r))
|
||||
if !h.auth.validateState(state) {
|
||||
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s",
|
||||
r.Method, r.URL.Path, utils.GetClientIP(r), state, utils.GetRequestSource(r))
|
||||
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
h.auth.states.Delete(state)
|
||||
|
||||
// 验证code参数
|
||||
if code == "" {
|
||||
@ -212,6 +277,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
// 记录令牌交换请求信息
|
||||
log.Printf("[Auth] DEBUG %s %s -> Exchanging code for token with redirect_uri=%s",
|
||||
r.Method, r.URL.Path, redirectURI)
|
||||
|
||||
resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token",
|
||||
url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
@ -229,8 +298,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
// 检查响应状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error status: %s from %s",
|
||||
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, utils.GetRequestSource(r))
|
||||
// 读取错误响应内容
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s",
|
||||
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, string(bodyBytes))
|
||||
http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user