Migrate authentication from Discourse SSO to OAuth 2.0 authentication flow

This commit is contained in:
wood chen 2025-02-08 20:07:42 +08:00
parent 8a2aec1ca3
commit bd8656542b
2 changed files with 84 additions and 81 deletions

8
.env.example Normal file
View File

@ -0,0 +1,8 @@
PORT=8080
GIN_MODE=debug
OAUTH_CLIENT_ID=your_client_id_here
OAUTH_CLIENT_SECRET=your_client_secret_here
OAUTH_REDIRECT_URI=https://aimodels-prices.q58.pro/api/auth/callback
OAUTH_AUTHORIZE_URL=https://connect.q58.pro/oauth/authorize
OAUTH_TOKEN_URL=https://connect.q58.pro/api/oauth/access_token
OAUTH_USER_URL=https://connect.q58.pro/api/oauth/user

View File

@ -1,17 +1,14 @@
package handlers
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
@ -99,45 +96,24 @@ func Login(c *gin.Context) {
return
}
// 生产环境使用 Discourse SSO
discourseURL := os.Getenv("DISCOURSE_URL")
if discourseURL == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Discourse URL not configured"})
// 生产环境使用 OAuth 2.0
clientID := os.Getenv("OAUTH_CLIENT_ID")
redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
authorizeURL := os.Getenv("OAUTH_AUTHORIZE_URL")
if clientID == "" || redirectURI == "" || authorizeURL == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "OAuth configuration not found"})
return
}
// 生成随机 nonce
nonce := make([]byte, 16)
if _, err := rand.Read(nonce); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate nonce"})
return
}
nonceStr := hex.EncodeToString(nonce)
// 构建授权 URL
authURL := fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s",
authorizeURL,
url.QueryEscape(clientID),
url.QueryEscape(redirectURI))
// 构建 payload
payload := url.Values{}
payload.Set("nonce", nonceStr)
payload.Set("return_sso_url", fmt.Sprintf("https://aimodels-prices.q58.pro/api/auth/callback"))
// Base64 编码
payloadStr := base64.StdEncoding.EncodeToString([]byte(payload.Encode()))
// 计算签名
ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET")
if ssoSecret == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"})
return
}
h := hmac.New(sha256.New, []byte(ssoSecret))
h.Write([]byte(payloadStr))
sig := hex.EncodeToString(h.Sum(nil))
// 构建重定向 URL
redirectURL := fmt.Sprintf("%s/session/sso?sso=%s&sig=%s",
discourseURL, url.QueryEscape(payloadStr), sig)
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
// 重定向到授权页面
c.Redirect(http.StatusTemporaryRedirect, authURL)
}
func Logout(c *gin.Context) {
@ -178,75 +154,94 @@ func GetUser(c *gin.Context) {
}
func AuthCallback(c *gin.Context) {
sso := c.Query("sso")
sig := c.Query("sig")
if sso == "" || sig == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing parameters"})
code := c.Query("code")
if code == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing authorization code"})
return
}
// 获取 SSO 密钥
ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET")
if ssoSecret == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"})
return
}
// 获取访问令牌
tokenURL := os.Getenv("OAUTH_TOKEN_URL")
clientID := os.Getenv("OAUTH_CLIENT_ID")
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
// 验证签名
h := hmac.New(sha256.New, []byte(ssoSecret))
h.Write([]byte(sso))
computedSig := hex.EncodeToString(h.Sum(nil))
if computedSig != sig {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid signature"})
return
}
// 构建请求体
data := url.Values{}
data.Set("code", code)
data.Set("client_id", clientID)
data.Set("client_secret", clientSecret)
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")
// 解码 SSO payload
payload, err := base64.StdEncoding.DecodeString(sso)
// 发送请求获取访问令牌
resp, err := http.PostForm(tokenURL, data)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid SSO payload"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get access token"})
return
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse token response"})
return
}
// 解析 payload
values, err := url.ParseQuery(string(payload))
// 使用访问令牌获取用户信息
userURL := os.Getenv("OAUTH_USER_URL")
req, err := http.NewRequest("GET", userURL, nil)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid payload format"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user info request"})
return
}
// 获取用户信息
username := values.Get("username")
email := values.Get("email")
groups := values.Get("groups")
admin := values.Get("admin") // Discourse 管理员标志
moderator := values.Get("moderator") // Discourse 版主标志
if username == "" || email == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing user information"})
req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
client := &http.Client{}
userResp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return
}
defer userResp.Body.Close()
// 判断用户角色
role := "user"
// 如果是管理员、版主或属于 admins 组,都赋予管理权限
if admin == "true" || moderator == "true" || (groups != "" && strings.Contains(groups, "admins")) {
role = "admin"
var userInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Admin bool `json:"admin"`
AvatarURL string `json:"avatar_url"`
Name string `json:"name"`
}
if err := json.NewDecoder(userResp.Body).Decode(&userInfo); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse user info"})
return
}
db := c.MustGet("db").(*sql.DB)
// 检查用户是否存在
var user models.User
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", email).Scan(
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan(
&user.ID, &user.Username, &user.Email, &user.Role)
role := "user"
if userInfo.Admin {
role = "admin"
}
if err == sql.ErrNoRows {
// 创建新用户
result, err := db.Exec(`
INSERT INTO user (username, email, role)
VALUES (?, ?, ?)`,
username, email, role)
userInfo.Username, userInfo.Email, role)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
@ -254,8 +249,8 @@ func AuthCallback(c *gin.Context) {
userID, _ := result.LastInsertId()
user = models.User{
ID: uint(userID),
Username: username,
Email: email,
Username: userInfo.Username,
Email: userInfo.Email,
Role: role,
}
} else if err != nil {