mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 13:41:59 +08:00
289 lines
7.9 KiB
Go
289 lines
7.9 KiB
Go
package handlers
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"aimodels-prices/models"
|
|
)
|
|
|
|
// generateSessionID returns a new random session identifier: 32 bytes read
// from crypto/rand, hex-encoded into a 64-character lowercase string.
//
// It panics if the system random source fails. Continuing with a partially
// filled buffer would yield predictable, forgeable session IDs, and the
// string-only signature leaves no error channel to callers.
func generateSessionID() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("generating session ID: %w", err))
	}
	return hex.EncodeToString(b)
}
|
|
|
|
func GetAuthStatus(c *gin.Context) {
|
|
cookie, err := c.Cookie("session")
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Not logged in"})
|
|
return
|
|
}
|
|
|
|
db := c.MustGet("db").(*sql.DB)
|
|
var session models.Session
|
|
err = db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan(
|
|
&session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt)
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"})
|
|
return
|
|
}
|
|
|
|
if session.ExpiresAt.Before(time.Now()) {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"})
|
|
return
|
|
}
|
|
|
|
user, err := session.GetUser(db)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
|
|
return
|
|
}
|
|
|
|
c.Set("user", user)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"user": user,
|
|
})
|
|
}
|
|
|
|
func Login(c *gin.Context) {
|
|
// 开发环境下使用测试账号
|
|
if gin.Mode() != gin.ReleaseMode {
|
|
db := c.MustGet("db").(*sql.DB)
|
|
|
|
// 创建测试用户(如果不存在)
|
|
var count int
|
|
err := db.QueryRow("SELECT COUNT(*) FROM user WHERE username = 'admin'").Scan(&count)
|
|
if err != nil || count == 0 {
|
|
_, err = db.Exec("INSERT INTO user (username, email, role) VALUES (?, ?, ?)",
|
|
"admin", "admin@test.com", "admin")
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create test user"})
|
|
return
|
|
}
|
|
}
|
|
|
|
// 获取用户ID
|
|
var userID uint
|
|
err = db.QueryRow("SELECT id FROM user WHERE username = 'admin'").Scan(&userID)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
|
|
return
|
|
}
|
|
|
|
// 创建会话
|
|
sessionID := generateSessionID()
|
|
expiresAt := time.Now().Add(24 * time.Hour)
|
|
_, err = db.Exec("INSERT INTO session (id, user_id, expires_at) VALUES (?, ?, ?)",
|
|
sessionID, userID, expiresAt)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
|
|
return
|
|
}
|
|
|
|
// 设置cookie
|
|
c.SetCookie("session", sessionID, int(24*time.Hour.Seconds()), "/", "aimodels-prices.q58.pro", true, true)
|
|
c.JSON(http.StatusOK, gin.H{"message": "Logged in successfully"})
|
|
return
|
|
}
|
|
|
|
// 生产环境使用 OAuth 2.0
|
|
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
|
redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
|
|
authorizeURL := os.Getenv("OAUTH_AUTHORIZE_URL")
|
|
|
|
if clientID == "" || redirectURI == "" || authorizeURL == "" {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "OAuth configuration not found"})
|
|
return
|
|
}
|
|
|
|
// 构建授权 URL
|
|
authURL := fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s",
|
|
authorizeURL,
|
|
url.QueryEscape(clientID),
|
|
url.QueryEscape(redirectURI))
|
|
|
|
// 返回授权 URL 而不是直接重定向
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"auth_url": authURL,
|
|
})
|
|
}
|
|
|
|
func Logout(c *gin.Context) {
|
|
cookie, err := c.Cookie("session")
|
|
if err == nil {
|
|
db := c.MustGet("db").(*sql.DB)
|
|
db.Exec("DELETE FROM session WHERE id = ?", cookie)
|
|
}
|
|
|
|
c.SetCookie("session", "", -1, "/", "aimodels-prices.q58.pro", true, true)
|
|
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
|
|
}
|
|
|
|
func GetUser(c *gin.Context) {
|
|
cookie, err := c.Cookie("session")
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Not logged in"})
|
|
return
|
|
}
|
|
|
|
db := c.MustGet("db").(*sql.DB)
|
|
var session models.Session
|
|
if err := db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan(
|
|
&session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt); err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"})
|
|
return
|
|
}
|
|
|
|
user, err := session.GetUser(db)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"user": user,
|
|
})
|
|
}
|
|
|
|
func AuthCallback(c *gin.Context) {
|
|
code := c.Query("code")
|
|
if code == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing authorization code"})
|
|
return
|
|
}
|
|
|
|
// 获取访问令牌
|
|
tokenURL := os.Getenv("OAUTH_TOKEN_URL")
|
|
clientID := os.Getenv("OAUTH_CLIENT_ID")
|
|
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
|
|
redirectURI := os.Getenv("OAUTH_REDIRECT_URI")
|
|
|
|
// 构建请求体
|
|
data := url.Values{}
|
|
data.Set("code", code)
|
|
data.Set("client_id", clientID)
|
|
data.Set("client_secret", clientSecret)
|
|
data.Set("redirect_uri", redirectURI)
|
|
data.Set("grant_type", "authorization_code")
|
|
|
|
// 发送请求获取访问令牌
|
|
resp, err := http.PostForm(tokenURL, data)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get access token"})
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
TokenType string `json:"token_type"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse token response"})
|
|
return
|
|
}
|
|
|
|
// 使用访问令牌获取用户信息
|
|
userURL := os.Getenv("OAUTH_USER_URL")
|
|
req, err := http.NewRequest("GET", userURL, nil)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user info request"})
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
|
client := &http.Client{}
|
|
userResp, err := client.Do(req)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
|
|
return
|
|
}
|
|
defer userResp.Body.Close()
|
|
|
|
var userInfo struct {
|
|
ID string `json:"id"`
|
|
Email string `json:"email"`
|
|
Username string `json:"username"`
|
|
Admin bool `json:"admin"`
|
|
AvatarURL string `json:"avatar_url"`
|
|
Name string `json:"name"`
|
|
}
|
|
|
|
if err := json.NewDecoder(userResp.Body).Decode(&userInfo); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse user info"})
|
|
return
|
|
}
|
|
|
|
db := c.MustGet("db").(*sql.DB)
|
|
|
|
// 检查用户是否存在
|
|
var user models.User
|
|
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan(
|
|
&user.ID, &user.Username, &user.Email, &user.Role)
|
|
|
|
role := "user"
|
|
if userInfo.Admin {
|
|
role = "admin"
|
|
}
|
|
|
|
if err == sql.ErrNoRows {
|
|
// 创建新用户
|
|
result, err := db.Exec(`
|
|
INSERT INTO user (username, email, role)
|
|
VALUES (?, ?, ?)`,
|
|
userInfo.Username, userInfo.Email, role)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
|
return
|
|
}
|
|
userID, _ := result.LastInsertId()
|
|
user = models.User{
|
|
ID: uint(userID),
|
|
Username: userInfo.Username,
|
|
Email: userInfo.Email,
|
|
Role: role,
|
|
}
|
|
} else if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Database error"})
|
|
return
|
|
} else {
|
|
// 更新现有用户的角色(如果需要)
|
|
if user.Role != role {
|
|
_, err = db.Exec("UPDATE user SET role = ? WHERE id = ?", role, user.ID)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user role"})
|
|
return
|
|
}
|
|
user.Role = role
|
|
}
|
|
}
|
|
|
|
// 创建会话
|
|
sessionID := generateSessionID()
|
|
expiresAt := time.Now().Add(24 * time.Hour)
|
|
_, err = db.Exec("INSERT INTO session (id, user_id, expires_at) VALUES (?, ?, ?)",
|
|
sessionID, user.ID, expiresAt)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
|
|
return
|
|
}
|
|
|
|
// 设置 cookie
|
|
c.SetCookie("session", sessionID, int(24*time.Hour.Seconds()), "/", "aimodels-prices.q58.pro", true, true)
|
|
|
|
// 重定向到前端
|
|
c.Redirect(http.StatusTemporaryRedirect, "https://aimodels-prices.q58.pro")
|
|
}
|