142 lines
3.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
// UserInfo describes an OAuth user profile. The JSON tags define the field
// names used when the struct is encoded to or decoded from JSON.
type UserInfo struct {
	ID       int    `json:"id"`       // numeric user identifier
	Username string `json:"username"` // login name
	Nickname string `json:"nickname"` // display name
	Email    string `json:"email"`    // e-mail address
	Avatar   string `json:"avatar"`   // avatar value as reported by the provider
}
// TokenCache is a single entry of the token cache: the user information
// associated with a token and the instant after which the entry is stale.
type TokenCache struct {
	UserInfo  *UserInfo // profile stored alongside the token
	ExpiresAt time.Time // the entry is considered valid only before this time
}
// AuthMiddleware is an HTTP authentication middleware that keeps a cache of
// already-validated tokens.
type AuthMiddleware struct {
	// tokenCache maps a token string to its *TokenCache entry.
	tokenCache sync.Map
	// cacheTTL is how long a freshly stored cache entry stays valid.
	cacheTTL time.Duration
}
// NewAuthMiddleware returns an AuthMiddleware with an empty token cache and a
// cache TTL of 30 minutes.
func NewAuthMiddleware() *AuthMiddleware {
	// Lifetime given to every newly cached token.
	const defaultCacheTTL = 30 * time.Minute

	am := new(AuthMiddleware)
	am.cacheTTL = defaultCacheTTL
	return am
}
// RequireAuth wraps next so that it only runs for requests carrying a valid
// OAuth bearer token in the Authorization header.
//
// The header must have the form "Bearer <token>". The scheme name is matched
// case-insensitively (HTTP auth schemes are case-insensitive) and surrounding
// whitespace around the token is ignored. A token found in the cache and not
// yet expired is accepted without further checks; otherwise it is validated
// through am.getUserInfo and, on success, cached for am.cacheTTL.
//
// Every rejection is answered with 401 Unauthorized and a short plain-text
// reason. The request is handed to next unchanged; the looked-up user
// information is not attached to it.
func (am *AuthMiddleware) RequireAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Read the token from the Authorization header.
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Authorization header required", http.StatusUnauthorized)
			return
		}

		// Split "<scheme> <credentials>" and compare the scheme without
		// regard to case, so "bearer" and "BEARER" are accepted as well.
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
			return
		}

		// Trim so that extra separator spaces do not become part of the
		// cache key or of the token sent to the provider.
		token := strings.TrimSpace(parts[1])
		if token == "" {
			http.Error(w, "Token required", http.StatusUnauthorized)
			return
		}

		// Fast path: a cached, unexpired entry lets the request through.
		if cached, ok := am.tokenCache.Load(token); ok {
			if entry, valid := cached.(*TokenCache); valid && time.Now().Before(entry.ExpiresAt) {
				next(w, r)
				return
			}
			// Expired (or unexpected) entry: drop it and validate again.
			am.tokenCache.Delete(token)
		}

		// Cache miss or stale entry: validate the token.
		userInfo, err := am.getUserInfo(token)
		if err != nil {
			http.Error(w, "Invalid token", http.StatusUnauthorized)
			return
		}

		// Remember the successful validation.
		am.tokenCache.Store(token, &TokenCache{
			UserInfo:  userInfo,
			ExpiresAt: time.Now().Add(am.cacheTTL),
		})

		// Validation succeeded; continue with the wrapped handler.
		next(w, r)
	}
}
// getUserInfo fetches the profile that belongs to accessToken from the OAuth
// provider's userinfo endpoint.
//
// The token is sent as a bearer credential in a GET request that is bounded
// by a 10 second client timeout. Any status other than 200 OK, a transport
// failure, or an undecodable JSON body is reported as an error, in which case
// the returned *UserInfo is nil.
func (am *AuthMiddleware) getUserInfo(accessToken string) (*UserInfo, error) {
	const endpoint = "https://connect.czl.net/api/oauth2/userinfo"

	request, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	request.Header.Set("Authorization", "Bearer "+accessToken)

	// The timeout keeps a slow provider from blocking the caller forever.
	httpClient := http.Client{Timeout: 10 * time.Second}
	response, err := httpClient.Do(request)
	if err != nil {
		return nil, fmt.Errorf("failed to get user info: %w", err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("failed to get user info, status: %d", status)
	}

	info := new(UserInfo)
	decoder := json.NewDecoder(response.Body)
	if err := decoder.Decode(info); err != nil {
		return nil, fmt.Errorf("failed to decode user info: %w", err)
	}
	return info, nil
}
// InvalidateToken removes token from the cache, e.g. on logout, so that it is
// no longer accepted from the cache alone. Removing a token that is not
// cached is a no-op.
func (am *AuthMiddleware) InvalidateToken(token string) {
	am.tokenCache.Delete(token)
}
// GetCacheStats reports cache statistics for monitoring purposes.
//
// The returned map holds "cached_tokens" (int, the number of entries currently
// stored, including expired ones that have not been removed yet) and
// "cache_ttl" (string, the configured TTL formatted by time.Duration.String).
func (am *AuthMiddleware) GetCacheStats() map[string]interface{} {
	stats := make(map[string]interface{}, 2)
	stats["cache_ttl"] = am.cacheTTL.String()

	// sync.Map has no length method, so the entries are counted by iteration.
	var total int
	am.tokenCache.Range(func(_, _ interface{}) bool {
		total++
		return true
	})
	stats["cached_tokens"] = total

	return stats
}