package middleware import ( "encoding/json" "fmt" "log" "net/http" "strings" ) // UserInfo OAuth用户信息结构 type UserInfo struct { ID int `json:"id"` Username string `json:"username"` Nickname string `json:"nickname"` Email string `json:"email"` Avatar string `json:"avatar"` } // AuthMiddleware 认证中间件 type AuthMiddleware struct{} // NewAuthMiddleware 创建新的认证中间件 func NewAuthMiddleware() *AuthMiddleware { return &AuthMiddleware{} } // RequireAuth 认证中间件,验证 OAuth 令牌 func (am *AuthMiddleware) RequireAuth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // 从 Authorization header 获取令牌 authHeader := r.Header.Get("Authorization") if authHeader == "" { http.Error(w, "Authorization header required", http.StatusUnauthorized) return } // 检查 Bearer 前缀 if !strings.HasPrefix(authHeader, "Bearer ") { http.Error(w, "Invalid authorization header format", http.StatusUnauthorized) return } token := strings.TrimPrefix(authHeader, "Bearer ") if token == "" { http.Error(w, "Token required", http.StatusUnauthorized) return } // 验证令牌(通过调用用户信息接口) userInfo, err := am.getUserInfo(token) if err != nil { log.Printf("Token validation failed: %v", err) http.Error(w, "Invalid token", http.StatusUnauthorized) return } // 令牌有效,继续处理请求 log.Printf("Authenticated user: %s (%s)", userInfo.Username, userInfo.Email) next(w, r) } } // getUserInfo 通过访问令牌获取用户信息 func (am *AuthMiddleware) getUserInfo(accessToken string) (*UserInfo, error) { req, err := http.NewRequest("GET", "https://connect.czl.net/api/oauth2/userinfo", nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Authorization", "Bearer "+accessToken) client := &http.Client{} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to get user info: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get user info, status: %d", resp.StatusCode) } var userInfo UserInfo if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { return nil, fmt.Errorf("failed to decode user info: %w", err) } return &userInfo, nil }