Migrate authentication from Discourse SSO to OAuth 2.0 authentication flow

This commit is contained in:
wood chen 2025-02-08 20:07:42 +08:00
parent 8a2aec1ca3
commit bd8656542b
2 changed files with 84 additions and 81 deletions

8
.env.example Normal file
View File

@ -0,0 +1,8 @@
PORT=8080
GIN_MODE=debug
OAUTH_CLIENT_ID=your_client_id_here
OAUTH_CLIENT_SECRET=your_client_secret_here
OAUTH_REDIRECT_URI=https://aimodels-prices.q58.pro/api/auth/callback
OAUTH_AUTHORIZE_URL=https://connect.q58.pro/oauth/authorize
OAUTH_TOKEN_URL=https://connect.q58.pro/api/oauth/access_token
OAUTH_USER_URL=https://connect.q58.pro/api/oauth/user

View File

@ -1,17 +1,14 @@
package handlers package handlers
import ( import (
"crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"database/sql" "database/sql"
"encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -99,45 +96,24 @@ func Login(c *gin.Context) {
return return
} }
// 生产环境使用 Discourse SSO // 生产环境使用 OAuth 2.0
discourseURL := os.Getenv("DISCOURSE_URL") clientID := os.Getenv("OAUTH_CLIENT_ID")
if discourseURL == "" { redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Discourse URL not configured"}) authorizeURL := os.Getenv("OAUTH_AUTHORIZE_URL")
if clientID == "" || redirectURI == "" || authorizeURL == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "OAuth configuration not found"})
return return
} }
// 生成随机 nonce // 构建授权 URL
nonce := make([]byte, 16) authURL := fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s",
if _, err := rand.Read(nonce); err != nil { authorizeURL,
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate nonce"}) url.QueryEscape(clientID),
return url.QueryEscape(redirectURI))
}
nonceStr := hex.EncodeToString(nonce)
// 构建 payload // 重定向到授权页面
payload := url.Values{} c.Redirect(http.StatusTemporaryRedirect, authURL)
payload.Set("nonce", nonceStr)
payload.Set("return_sso_url", fmt.Sprintf("https://aimodels-prices.q58.pro/api/auth/callback"))
// Base64 编码
payloadStr := base64.StdEncoding.EncodeToString([]byte(payload.Encode()))
// 计算签名
ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET")
if ssoSecret == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"})
return
}
h := hmac.New(sha256.New, []byte(ssoSecret))
h.Write([]byte(payloadStr))
sig := hex.EncodeToString(h.Sum(nil))
// 构建重定向 URL
redirectURL := fmt.Sprintf("%s/session/sso?sso=%s&sig=%s",
discourseURL, url.QueryEscape(payloadStr), sig)
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func Logout(c *gin.Context) { func Logout(c *gin.Context) {
@ -178,75 +154,94 @@ func GetUser(c *gin.Context) {
} }
func AuthCallback(c *gin.Context) { func AuthCallback(c *gin.Context) {
sso := c.Query("sso") code := c.Query("code")
sig := c.Query("sig") if code == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing authorization code"})
if sso == "" || sig == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing parameters"})
return return
} }
// 获取 SSO 密钥 // 获取访问令牌
ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET") tokenURL := os.Getenv("OAUTH_TOKEN_URL")
if ssoSecret == "" { clientID := os.Getenv("OAUTH_CLIENT_ID")
c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"}) clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
return redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
}
// 验证签名 // 构建请求体
h := hmac.New(sha256.New, []byte(ssoSecret)) data := url.Values{}
h.Write([]byte(sso)) data.Set("code", code)
computedSig := hex.EncodeToString(h.Sum(nil)) data.Set("client_id", clientID)
if computedSig != sig { data.Set("client_secret", clientSecret)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid signature"}) data.Set("redirect_uri", redirectURI)
return data.Set("grant_type", "authorization_code")
}
// 解码 SSO payload // 发送请求获取访问令牌
payload, err := base64.StdEncoding.DecodeString(sso) resp, err := http.PostForm(tokenURL, data)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid SSO payload"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get access token"})
return
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse token response"})
return return
} }
// 解析 payload // 使用访问令牌获取用户信息
values, err := url.ParseQuery(string(payload)) userURL := os.Getenv("OAUTH_USER_URL")
req, err := http.NewRequest("GET", userURL, nil)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid payload format"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user info request"})
return return
} }
// 获取用户信息 req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
username := values.Get("username") client := &http.Client{}
email := values.Get("email") userResp, err := client.Do(req)
groups := values.Get("groups") if err != nil {
admin := values.Get("admin") // Discourse 管理员标志 c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
moderator := values.Get("moderator") // Discourse 版主标志
if username == "" || email == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing user information"})
return return
} }
defer userResp.Body.Close()
// 判断用户角色 var userInfo struct {
role := "user" ID string `json:"id"`
// 如果是管理员、版主或属于 admins 组,都赋予管理权限 Email string `json:"email"`
if admin == "true" || moderator == "true" || (groups != "" && strings.Contains(groups, "admins")) { Username string `json:"username"`
role = "admin" Admin bool `json:"admin"`
AvatarURL string `json:"avatar_url"`
Name string `json:"name"`
}
if err := json.NewDecoder(userResp.Body).Decode(&userInfo); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse user info"})
return
} }
db := c.MustGet("db").(*sql.DB) db := c.MustGet("db").(*sql.DB)
// 检查用户是否存在 // 检查用户是否存在
var user models.User var user models.User
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", email).Scan( err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan(
&user.ID, &user.Username, &user.Email, &user.Role) &user.ID, &user.Username, &user.Email, &user.Role)
role := "user"
if userInfo.Admin {
role = "admin"
}
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// 创建新用户 // 创建新用户
result, err := db.Exec(` result, err := db.Exec(`
INSERT INTO user (username, email, role) INSERT INTO user (username, email, role)
VALUES (?, ?, ?)`, VALUES (?, ?, ?)`,
username, email, role) userInfo.Username, userInfo.Email, role)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return return
@ -254,8 +249,8 @@ func AuthCallback(c *gin.Context) {
userID, _ := result.LastInsertId() userID, _ := result.LastInsertId()
user = models.User{ user = models.User{
ID: uint(userID), ID: uint(userID),
Username: username, Username: userInfo.Username,
Email: email, Email: userInfo.Email,
Role: role, Role: role,
} }
} else if err != nil { } else if err != nil {