mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 13:41:59 +08:00
Migrate authentication from Discourse SSO to OAuth 2.0 authentication flow
This commit is contained in:
parent
8a2aec1ca3
commit
bd8656542b
8
.env.example
Normal file
8
.env.example
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
PORT=8080
|
||||||
|
GIN_MODE=debug
|
||||||
|
OAUTH_CLIENT_ID=your_client_id_here
|
||||||
|
OAUTH_CLIENT_SECRET=your_client_secret_here
|
||||||
|
OAUTH_REDIRECT_URI=https://aimodels-prices.q58.pro/api/auth/callback
|
||||||
|
OAUTH_AUTHORIZE_URL=https://connect.q58.pro/oauth/authorize
|
||||||
|
OAUTH_TOKEN_URL=https://connect.q58.pro/api/oauth/access_token
|
||||||
|
OAUTH_USER_URL=https://connect.q58.pro/api/oauth/user
|
@ -1,17 +1,14 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -99,45 +96,24 @@ func Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生产环境使用 Discourse SSO
|
// 生产环境使用 OAuth 2.0
|
||||||
discourseURL := os.Getenv("DISCOURSE_URL")
|
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
||||||
if discourseURL == "" {
|
redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Discourse URL not configured"})
|
authorizeURL := os.Getenv("OAUTH_AUTHORIZE_URL")
|
||||||
|
|
||||||
|
if clientID == "" || redirectURI == "" || authorizeURL == "" {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "OAuth configuration not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成随机 nonce
|
// 构建授权 URL
|
||||||
nonce := make([]byte, 16)
|
authURL := fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s",
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
authorizeURL,
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate nonce"})
|
url.QueryEscape(clientID),
|
||||||
return
|
url.QueryEscape(redirectURI))
|
||||||
}
|
|
||||||
nonceStr := hex.EncodeToString(nonce)
|
|
||||||
|
|
||||||
// 构建 payload
|
// 重定向到授权页面
|
||||||
payload := url.Values{}
|
c.Redirect(http.StatusTemporaryRedirect, authURL)
|
||||||
payload.Set("nonce", nonceStr)
|
|
||||||
payload.Set("return_sso_url", fmt.Sprintf("https://aimodels-prices.q58.pro/api/auth/callback"))
|
|
||||||
|
|
||||||
// Base64 编码
|
|
||||||
payloadStr := base64.StdEncoding.EncodeToString([]byte(payload.Encode()))
|
|
||||||
|
|
||||||
// 计算签名
|
|
||||||
ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET")
|
|
||||||
if ssoSecret == "" {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h := hmac.New(sha256.New, []byte(ssoSecret))
|
|
||||||
h.Write([]byte(payloadStr))
|
|
||||||
sig := hex.EncodeToString(h.Sum(nil))
|
|
||||||
|
|
||||||
// 构建重定向 URL
|
|
||||||
redirectURL := fmt.Sprintf("%s/session/sso?sso=%s&sig=%s",
|
|
||||||
discourseURL, url.QueryEscape(payloadStr), sig)
|
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Logout(c *gin.Context) {
|
func Logout(c *gin.Context) {
|
||||||
@ -178,75 +154,94 @@ func GetUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AuthCallback(c *gin.Context) {
|
func AuthCallback(c *gin.Context) {
|
||||||
sso := c.Query("sso")
|
code := c.Query("code")
|
||||||
sig := c.Query("sig")
|
if code == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing authorization code"})
|
||||||
if sso == "" || sig == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing parameters"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 SSO 密钥
|
// 获取访问令牌
|
||||||
ssoSecret := os.Getenv("DISCOURSE_SSO_SECRET")
|
tokenURL := os.Getenv("OAUTH_TOKEN_URL")
|
||||||
if ssoSecret == "" {
|
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "SSO secret not configured"})
|
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
|
||||||
return
|
redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
|
||||||
}
|
|
||||||
|
|
||||||
// 验证签名
|
// 构建请求体
|
||||||
h := hmac.New(sha256.New, []byte(ssoSecret))
|
data := url.Values{}
|
||||||
h.Write([]byte(sso))
|
data.Set("code", code)
|
||||||
computedSig := hex.EncodeToString(h.Sum(nil))
|
data.Set("client_id", clientID)
|
||||||
if computedSig != sig {
|
data.Set("client_secret", clientSecret)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid signature"})
|
data.Set("redirect_uri", redirectURI)
|
||||||
return
|
data.Set("grant_type", "authorization_code")
|
||||||
}
|
|
||||||
|
|
||||||
// 解码 SSO payload
|
// 发送请求获取访问令牌
|
||||||
payload, err := base64.StdEncoding.DecodeString(sso)
|
resp, err := http.PostForm(tokenURL, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid SSO payload"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get access token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse token response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 payload
|
// 使用访问令牌获取用户信息
|
||||||
values, err := url.ParseQuery(string(payload))
|
userURL := os.Getenv("OAUTH_USER_URL")
|
||||||
|
req, err := http.NewRequest("GET", userURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid payload format"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user info request"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取用户信息
|
req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||||
username := values.Get("username")
|
client := &http.Client{}
|
||||||
email := values.Get("email")
|
userResp, err := client.Do(req)
|
||||||
groups := values.Get("groups")
|
if err != nil {
|
||||||
admin := values.Get("admin") // Discourse 管理员标志
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
|
||||||
moderator := values.Get("moderator") // Discourse 版主标志
|
|
||||||
if username == "" || email == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing user information"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer userResp.Body.Close()
|
||||||
|
|
||||||
// 判断用户角色
|
var userInfo struct {
|
||||||
role := "user"
|
ID string `json:"id"`
|
||||||
// 如果是管理员、版主或属于 admins 组,都赋予管理权限
|
Email string `json:"email"`
|
||||||
if admin == "true" || moderator == "true" || (groups != "" && strings.Contains(groups, "admins")) {
|
Username string `json:"username"`
|
||||||
role = "admin"
|
Admin bool `json:"admin"`
|
||||||
|
AvatarURL string `json:"avatar_url"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(userResp.Body).Decode(&userInfo); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse user info"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db := c.MustGet("db").(*sql.DB)
|
db := c.MustGet("db").(*sql.DB)
|
||||||
|
|
||||||
// 检查用户是否存在
|
// 检查用户是否存在
|
||||||
var user models.User
|
var user models.User
|
||||||
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", email).Scan(
|
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan(
|
||||||
&user.ID, &user.Username, &user.Email, &user.Role)
|
&user.ID, &user.Username, &user.Email, &user.Role)
|
||||||
|
|
||||||
|
role := "user"
|
||||||
|
if userInfo.Admin {
|
||||||
|
role = "admin"
|
||||||
|
}
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// 创建新用户
|
// 创建新用户
|
||||||
result, err := db.Exec(`
|
result, err := db.Exec(`
|
||||||
INSERT INTO user (username, email, role)
|
INSERT INTO user (username, email, role)
|
||||||
VALUES (?, ?, ?)`,
|
VALUES (?, ?, ?)`,
|
||||||
username, email, role)
|
userInfo.Username, userInfo.Email, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||||
return
|
return
|
||||||
@ -254,8 +249,8 @@ func AuthCallback(c *gin.Context) {
|
|||||||
userID, _ := result.LastInsertId()
|
userID, _ := result.LastInsertId()
|
||||||
user = models.User{
|
user = models.User{
|
||||||
ID: uint(userID),
|
ID: uint(userID),
|
||||||
Username: username,
|
Username: userInfo.Username,
|
||||||
Email: email,
|
Email: userInfo.Email,
|
||||||
Role: role,
|
Role: role,
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user