mirror of
https://github.com/woodchen-ink/proxy-go.git
synced 2025-07-18 16:41:54 +08:00
feat(auth): 增强OAuth认证状态管理和安全性
- 新增 state 状态管理机制,增加 10 分钟有效期 - 实现 generateState 和 validateState 方法 - 优化 LoginHandler 和 OAuthCallbackHandler 中的状态验证逻辑 - 添加更详细的调试和错误日志记录 - 完善回调地址生成逻辑,支持更多网络环境 - 在 OAuth 授权请求中添加 scope 参数
This commit is contained in:
parent
2626f63770
commit
7f4a964163
@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天
|
tokenExpiry = 30 * 24 * time.Hour // Token 过期时间为 30 天
|
||||||
|
stateExpiry = 10 * time.Minute // State 过期时间为 10 分钟
|
||||||
)
|
)
|
||||||
|
|
||||||
type OAuthUserInfo struct {
|
type OAuthUserInfo struct {
|
||||||
@ -44,6 +45,7 @@ type OAuthToken struct {
|
|||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenInfo struct {
|
type tokenInfo struct {
|
||||||
@ -52,6 +54,11 @@ type tokenInfo struct {
|
|||||||
username string
|
username string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type stateInfo struct {
|
||||||
|
createdAt time.Time
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type authManager struct {
|
type authManager struct {
|
||||||
tokens sync.Map
|
tokens sync.Map
|
||||||
states sync.Map
|
states sync.Map
|
||||||
@ -60,6 +67,7 @@ type authManager struct {
|
|||||||
func newAuthManager() *authManager {
|
func newAuthManager() *authManager {
|
||||||
am := &authManager{}
|
am := &authManager{}
|
||||||
go am.cleanExpiredTokens()
|
go am.cleanExpiredTokens()
|
||||||
|
go am.cleanExpiredStates()
|
||||||
return am
|
return am
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,6 +77,27 @@ func (am *authManager) generateToken() string {
|
|||||||
return base64.URLEncoding.EncodeToString(b)
|
return base64.URLEncoding.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *authManager) generateState() string {
|
||||||
|
state := am.generateToken()
|
||||||
|
am.states.Store(state, stateInfo{
|
||||||
|
createdAt: time.Now(),
|
||||||
|
expiresAt: time.Now().Add(stateExpiry),
|
||||||
|
})
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *authManager) validateState(state string) bool {
|
||||||
|
if info, ok := am.states.Load(state); ok {
|
||||||
|
stateInfo := info.(stateInfo)
|
||||||
|
if time.Now().Before(stateInfo.expiresAt) {
|
||||||
|
am.states.Delete(state) // 使用后立即删除
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
am.states.Delete(state) // 过期也删除
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (am *authManager) addToken(token string, username string, expiry time.Duration) {
|
func (am *authManager) addToken(token string, username string, expiry time.Duration) {
|
||||||
am.tokens.Store(token, tokenInfo{
|
am.tokens.Store(token, tokenInfo{
|
||||||
createdAt: time.Now(),
|
createdAt: time.Now(),
|
||||||
@ -102,6 +131,20 @@ func (am *authManager) cleanExpiredTokens() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *authManager) cleanExpiredStates() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
for range ticker.C {
|
||||||
|
am.states.Range(func(key, value interface{}) bool {
|
||||||
|
state := key.(string)
|
||||||
|
info := value.(stateInfo)
|
||||||
|
if time.Now().After(info.expiresAt) {
|
||||||
|
am.states.Delete(state)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CheckAuth 检查认证令牌是否有效
|
// CheckAuth 检查认证令牌是否有效
|
||||||
func (h *ProxyHandler) CheckAuth(token string) bool {
|
func (h *ProxyHandler) CheckAuth(token string) bool {
|
||||||
return h.auth.validateToken(token)
|
return h.auth.validateToken(token)
|
||||||
@ -150,30 +193,48 @@ func (h *ProxyHandler) AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|||||||
|
|
||||||
// getCallbackURL 从请求中获取回调地址
|
// getCallbackURL 从请求中获取回调地址
|
||||||
func getCallbackURL(r *http.Request) string {
|
func getCallbackURL(r *http.Request) string {
|
||||||
if os.Getenv("OAUTH_REDIRECT_URI") != "" {
|
if redirectURI := os.Getenv("OAUTH_REDIRECT_URI"); redirectURI != "" {
|
||||||
return os.Getenv("OAUTH_REDIRECT_URI")
|
// 验证URI格式
|
||||||
} else {
|
if _, err := url.Parse(redirectURI); err == nil {
|
||||||
|
log.Printf("[Auth] DEBUG Using configured OAUTH_REDIRECT_URI: %s", redirectURI)
|
||||||
|
return redirectURI
|
||||||
|
}
|
||||||
|
log.Printf("[Auth] WARNING Invalid OAUTH_REDIRECT_URI format: %s", redirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更可靠地检测协议
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, r.Host)
|
|
||||||
|
// 考虑X-Forwarded-Host头
|
||||||
|
host := r.Host
|
||||||
|
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||||
|
host = forwardedHost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callbackURL := fmt.Sprintf("%s://%s/admin/api/oauth/callback", scheme, host)
|
||||||
|
log.Printf("[Auth] DEBUG Generated callback URL: %s", callbackURL)
|
||||||
|
return callbackURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginHandler 处理登录请求,重定向到 OAuth 授权页面
|
// LoginHandler 处理登录请求,重定向到 OAuth 授权页面
|
||||||
func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *ProxyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
state := h.auth.generateToken()
|
state := h.auth.generateState()
|
||||||
h.auth.states.Store(state, time.Now())
|
|
||||||
|
|
||||||
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
||||||
redirectURI := getCallbackURL(r)
|
redirectURI := getCallbackURL(r)
|
||||||
|
|
||||||
|
// 记录生成的state和重定向URI
|
||||||
|
log.Printf("[Auth] DEBUG %s %s -> Generated state=%s, redirect_uri=%s",
|
||||||
|
r.Method, r.URL.Path, state, redirectURI)
|
||||||
|
|
||||||
authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s",
|
authURL := fmt.Sprintf("https://connect.czl.net/oauth2/authorize?%s",
|
||||||
url.Values{
|
url.Values{
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"client_id": {clientID},
|
"client_id": {clientID},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {redirectURI},
|
||||||
|
"scope": {"read write"}, // 添加scope参数
|
||||||
"state": {state},
|
"state": {state},
|
||||||
}.Encode())
|
}.Encode())
|
||||||
|
|
||||||
@ -185,13 +246,17 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
|
|||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
// 记录完整请求信息
|
||||||
|
log.Printf("[Auth] DEBUG %s %s -> Callback received with state=%s, code=%s, full URL: %s",
|
||||||
|
r.Method, r.URL.Path, state, code, r.URL.String())
|
||||||
|
|
||||||
// 验证 state
|
// 验证 state
|
||||||
if _, ok := h.auth.states.Load(state); !ok {
|
if !h.auth.validateState(state) {
|
||||||
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state from %s", r.Method, r.URL.Path, utils.GetClientIP(r), utils.GetRequestSource(r))
|
log.Printf("[Auth] ERR %s %s -> 400 (%s) invalid state '%s' from %s",
|
||||||
|
r.Method, r.URL.Path, utils.GetClientIP(r), state, utils.GetRequestSource(r))
|
||||||
http.Error(w, "Invalid state", http.StatusBadRequest)
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.auth.states.Delete(state)
|
|
||||||
|
|
||||||
// 验证code参数
|
// 验证code参数
|
||||||
if code == "" {
|
if code == "" {
|
||||||
@ -212,6 +277,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录令牌交换请求信息
|
||||||
|
log.Printf("[Auth] DEBUG %s %s -> Exchanging code for token with redirect_uri=%s",
|
||||||
|
r.Method, r.URL.Path, redirectURI)
|
||||||
|
|
||||||
resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token",
|
resp, err := http.PostForm("https://connect.czl.net/api/oauth2/token",
|
||||||
url.Values{
|
url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
@ -229,8 +298,10 @@ func (h *ProxyHandler) OAuthCallbackHandler(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// 检查响应状态码
|
// 检查响应状态码
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error status: %s from %s",
|
// 读取错误响应内容
|
||||||
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, utils.GetRequestSource(r))
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
log.Printf("[Auth] ERR %s %s -> %d (%s) OAuth server returned error: %s, response: %s",
|
||||||
|
r.Method, r.URL.Path, resp.StatusCode, utils.GetClientIP(r), resp.Status, string(bodyBytes))
|
||||||
http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError)
|
http.Error(w, "OAuth server error: "+resp.Status, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user