91 lines
2.4 KiB
Go

package middleware
import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"
)
// UserInfo is the user profile decoded from the OAuth provider's userinfo
// response. Only the tagged fields below are decoded; encoding/json ignores
// any other keys present in the response.
type UserInfo struct {
	ID       int    `json:"id"`
	Username string `json:"username"`
	Nickname string `json:"nickname"`
	Email    string `json:"email"`
	Avatar   string `json:"avatar"`
}
// AuthMiddleware guards HTTP handlers behind OAuth bearer-token validation.
// It carries no state of its own.
type AuthMiddleware struct{}

// NewAuthMiddleware returns a ready-to-use *AuthMiddleware.
func NewAuthMiddleware() *AuthMiddleware {
	return new(AuthMiddleware)
}
// RequireAuth returns a handler that invokes next only when the request
// carries a bearer token that the OAuth provider's userinfo endpoint accepts.
//
// The token is read from an "Authorization: Bearer <token>" header. The scheme
// name is matched case-insensitively, as HTTP authentication schemes are
// case-insensitive (RFC 7235 §2.1), and whitespace around the token is ignored.
// Each of the following gets a 401 response and next is not called:
//   - a missing header
//   - a different scheme
//   - an empty token
//   - a token the provider rejects
//
// Every 401 carries a WWW-Authenticate challenge, as RFC 6750 §3 requires.
func (am *AuthMiddleware) RequireAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// deny writes a 401 with the given Bearer challenge and plain-text body.
		deny := func(challenge, msg string) {
			w.Header().Set("WWW-Authenticate", challenge)
			http.Error(w, msg, http.StatusUnauthorized)
		}

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			deny("Bearer", "Authorization header required")
			return
		}

		token, ok := parseBearerToken(authHeader)
		if !ok {
			deny("Bearer", "Invalid authorization header format")
			return
		}
		if token == "" {
			deny("Bearer", "Token required")
			return
		}

		// Validate the token by asking the provider whom it belongs to.
		userInfo, err := am.getUserInfo(token)
		if err != nil {
			log.Printf("Token validation failed: %v", err)
			deny(`Bearer error="invalid_token"`, "Invalid token")
			return
		}

		// NOTE(review): this logs the user's e-mail address on every
		// authenticated request; confirm that is acceptable for your logs.
		log.Printf("Authenticated user: %s (%s)", userInfo.Username, userInfo.Email)
		next(w, r)
	}
}

// parseBearerToken extracts the credential from an Authorization header value
// of the form "Bearer <token>". ok is false when the value does not start with
// the Bearer scheme followed by a space; the scheme name is compared
// case-insensitively. The returned token has surrounding whitespace removed
// and may be empty.
func parseBearerToken(header string) (token string, ok bool) {
	const prefix = "Bearer "
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", false
	}
	return strings.TrimSpace(header[len(prefix):]), true
}
// userInfoEndpoint is the OAuth provider endpoint queried to validate tokens.
const userInfoEndpoint = "https://connect.czl.net/api/oauth2/userinfo"

// userInfoClient is shared by all token validations. Sharing it lets the
// transport pool connections to the provider, instead of building a new
// client on every request.
//
// The timeout stops a slow or unresponsive provider from blocking request
// handlers indefinitely. The value is a chosen upper bound; tune it as needed.
var userInfoClient = &http.Client{Timeout: 10 * time.Second}

// getUserInfo validates accessToken by presenting it as a bearer token to the
// provider's userinfo endpoint, and decodes the returned profile.
//
// It returns an error in any of these cases:
//   - the request cannot be built
//   - the request cannot be sent, including when userInfoClient's timeout elapses
//   - the provider answers with any status other than 200 OK
//   - the response body cannot be decoded into UserInfo
//
// NOTE(review): the outgoing request does not inherit the incoming request's
// context, so a client disconnect does not cancel it; only the timeout does.
func (am *AuthMiddleware) getUserInfo(accessToken string) (*UserInfo, error) {
	req, err := http.NewRequest(http.MethodGet, userInfoEndpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := userInfoClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to get user info: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get user info, status: %d", resp.StatusCode)
	}

	var userInfo UserInfo
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		return nil, fmt.Errorf("failed to decode user info: %w", err)
	}
	return &userInfo, nil
}