diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..f53293b --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +PORT=8080 +GIN_MODE=debug +OAUTH_CLIENT_ID=your_client_id_here +OAUTH_CLIENT_SECRET=your_client_secret_here +OAUTH_REDIRECT_URI=https://aimodels-prices.q58.pro/api/auth/callback +OAUTH_AUTHORIZE_URL=https://connect.q58.pro/oauth/authorize +OAUTH_TOKEN_URL=https://connect.q58.pro/api/oauth/access_token +OAUTH_USER_URL=https://connect.q58.pro/api/oauth/user \ No newline at end of file diff --git a/backend/handlers/auth.go b/backend/handlers/auth.go index b3473c9..6fd2772 100644 --- a/backend/handlers/auth.go +++ b/backend/handlers/auth.go @@ -1,17 +1,14 @@ package handlers import ( - "crypto/hmac" "crypto/rand" - "crypto/sha256" "database/sql" - "encoding/base64" "encoding/hex" + "encoding/json" "fmt" "net/http" "net/url" "os" - "strings" "time" "github.com/gin-gonic/gin" @@ -99,45 +96,24 @@ func Login(c *gin.Context) { return } - // 生产环境使用 Discourse SSO - discourseURL := os.Getenv("DISCOURSE_URL") - if discourseURL == "" { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Discourse URL not configured"}) + // 生产环境使用 OAuth 2.0 + clientID := os.Getenv("OAUTH_CLIENT_ID") + redirectURI := os.Getenv("OAUTH_REDIRECT_URI") + authorizeURL := os.Getenv("OAUTH_AUTHORIZE_URL") + + if clientID == "" || redirectURI == "" || authorizeURL == "" { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OAuth configuration not found"}) return } - // 生成随机 nonce - nonce := make([]byte, 16) - if _, err := rand.Read(nonce); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate nonce"}) - return - } - nonceStr := hex.EncodeToString(nonce) + // 构建授权 URL + authURL := fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s", + authorizeURL, + url.QueryEscape(clientID), + url.QueryEscape(redirectURI)) - // 构建 payload - payload := url.Values{} - payload.Set("nonce", nonceStr) - payload.Set("return_sso_url", fmt.Sprintf("https://aimodels-prices.q58.pro/api/auth/callback")) - - // Base64 编码 - payloadStr := base64.StdEncoding.EncodeToString([]byte(payload.Encode())) - - // 计算签名 - ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET") - if ssoSecret == "" { - c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"}) - return - } - - h := hmac.New(sha256.New, []byte(ssoSecret)) - h.Write([]byte(payloadStr)) - sig := hex.EncodeToString(h.Sum(nil)) - - // 构建重定向 URL - redirectURL := fmt.Sprintf("%s/session/sso?sso=%s&sig=%s", - discourseURL, url.QueryEscape(payloadStr), sig) - - c.Redirect(http.StatusTemporaryRedirect, redirectURL) + // 重定向到授权页面 + c.Redirect(http.StatusTemporaryRedirect, authURL) } func Logout(c *gin.Context) { @@ -178,75 +154,94 @@ func GetUser(c *gin.Context) { } func AuthCallback(c *gin.Context) { - sso := c.Query("sso") - sig := c.Query("sig") - - if sso == "" || sig == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Missing parameters"}) + code := c.Query("code") + if code == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Missing authorization code"}) return } - // 获取 SSO 密钥 - ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET") - if ssoSecret == "" { - c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"}) - return - } + // 获取访问令牌 + tokenURL := os.Getenv("OAUTH_TOKEN_URL") + clientID := os.Getenv("OAUTH_CLIENT_ID") + clientSecret := os.Getenv("OAUTH_CLIENT_SECRET") + redirectURI := os.Getenv("OAUTH_REDIRECT_URI") - // 验证签名 - h := hmac.New(sha256.New, []byte(ssoSecret)) - h.Write([]byte(sso)) - computedSig := hex.EncodeToString(h.Sum(nil)) - if computedSig != sig { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid signature"}) - return - } + // 构建请求体 + data := url.Values{} + data.Set("code", code) + data.Set("client_id", clientID) + data.Set("client_secret", clientSecret) + data.Set("redirect_uri", redirectURI) + data.Set("grant_type", "authorization_code") - // 解码 SSO payload - payload, err := base64.StdEncoding.DecodeString(sso) + // 发送请求获取访问令牌 + resp, err := http.PostForm(tokenURL, data) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid SSO payload"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get access token"}) + return + } + defer resp.Body.Close() + + var tokenResp struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse token response"}) return } - // 解析 payload - values, err := url.ParseQuery(string(payload)) + // 使用访问令牌获取用户信息 + userURL := os.Getenv("OAUTH_USER_URL") + req, err := http.NewRequest("GET", userURL, nil) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid payload format"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user info request"}) return } - // 获取用户信息 - username := values.Get("username") - email := values.Get("email") - groups := values.Get("groups") - admin := values.Get("admin") // Discourse 管理员标志 - moderator := values.Get("moderator") // Discourse 版主标志 - if username == "" || email == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Missing user information"}) + req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) + client := &http.Client{} + userResp, err := client.Do(req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"}) return } + defer userResp.Body.Close() - // 判断用户角色 - role := "user" - // 如果是管理员、版主或属于 admins 组,都赋予管理权限 - if admin == "true" || moderator == "true" || (groups != "" && strings.Contains(groups, "admins")) { - role = "admin" + var userInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Username string `json:"username"` + Admin bool `json:"admin"` + AvatarURL string `json:"avatar_url"` + Name string `json:"name"` + } + + if err := json.NewDecoder(userResp.Body).Decode(&userInfo); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse user info"}) + return } db := c.MustGet("db").(*sql.DB) // 检查用户是否存在 var user models.User - err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", email).Scan( + err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan( &user.ID, &user.Username, &user.Email, &user.Role) + role := "user" + if userInfo.Admin { + role = "admin" + } + if err == sql.ErrNoRows { // 创建新用户 result, err := db.Exec(` INSERT INTO user (username, email, role) VALUES (?, ?, ?)`, - username, email, role) + userInfo.Username, userInfo.Email, role) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) return @@ -254,8 +249,8 @@ func AuthCallback(c *gin.Context) { userID, _ := result.LastInsertId() user = models.User{ ID: uint(userID), - Username: username, - Email: email, + Username: userInfo.Username, + Email: userInfo.Email, Role: role, } } else if err != nil {