重构数据库层并迁移到 GORM

- 将数据库操作从原生 SQL 迁移到 GORM ORM
- 更新模型定义,添加 GORM 标签和关系
- 移除手动创建表的 SQL 方法,改用 AutoMigrate
- 更新所有数据库相关处理逻辑以适配 GORM
- 升级 Go 版本和依赖库
- 移除数据库和路由中间件,简化项目结构
This commit is contained in:
wood chen 2025-03-06 23:32:18 +08:00
parent 0bdadcfef7
commit aeb05f790a
14 changed files with 263 additions and 471 deletions

View File

@ -1,18 +1,19 @@
package database package database
import ( import (
"database/sql"
"fmt" "fmt"
"log" "log"
_ "github.com/go-sql-driver/mysql" "gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"aimodels-prices/config" "aimodels-prices/config"
"aimodels-prices/models" "aimodels-prices/models"
) )
// DB 是数据库连接的全局实例 // DB 是数据库连接的全局实例
var DB *sql.DB var DB *gorm.DB
// InitDB 初始化数据库连接 // InitDB 初始化数据库连接
func InitDB(cfg *config.Config) error { func InitDB(cfg *config.Config) error {
@ -28,59 +29,47 @@ func InitDB(cfg *config.Config) error {
) )
// 连接MySQL // 连接MySQL
DB, err = sql.Open("mysql", dsn) DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to MySQL: %v", err) return fmt.Errorf("failed to connect to MySQL: %v", err)
} }
// 测试连接 // 获取底层的SQL DB
if err = DB.Ping(); err != nil { sqlDB, err := DB.DB()
return fmt.Errorf("failed to ping MySQL: %v", err) if err != nil {
return fmt.Errorf("failed to get underlying SQL DB: %v", err)
} }
// 设置连接池参数 // 设置连接池参数
DB.SetMaxOpenConns(10) sqlDB.SetMaxOpenConns(10)
DB.SetMaxIdleConns(5) sqlDB.SetMaxIdleConns(5)
// 创建表结构 // 自动迁移表结构
if err = createTables(); err != nil { if err = migrateModels(); err != nil {
return fmt.Errorf("failed to create tables: %v", err) return fmt.Errorf("failed to migrate models: %v", err)
} }
return nil return nil
} }
// createTables 创建数据库表 // migrateModels 自动迁移模型到数据库表
func createTables() error { func migrateModels() error {
// 创建用户表 // 自动迁移模型
if _, err := DB.Exec(models.CreateUserTableSQL()); err != nil { if err := DB.AutoMigrate(
log.Printf("Failed to create user table: %v", err) &models.ModelType{},
&models.Price{},
&models.Provider{},
&models.User{},
&models.Session{},
); err != nil {
log.Printf("Failed to migrate tables: %v", err)
return err return err
} }
// 创建会话表 // 这里可以添加其他模型的迁移
if _, err := DB.Exec(models.CreateSessionTableSQL()); err != nil { // 例如DB.AutoMigrate(&models.User{})
log.Printf("Failed to create session table: %v", err)
return err
}
// 创建模型厂商表
if _, err := DB.Exec(models.CreateProviderTableSQL()); err != nil {
log.Printf("Failed to create provider table: %v", err)
return err
}
// 创建模型类型表
if _, err := DB.Exec(models.CreateModelTypeTableSQL()); err != nil {
log.Printf("Failed to create model_type table: %v", err)
return err
}
// 创建价格表
if _, err := DB.Exec(models.CreatePriceTableSQL()); err != nil {
log.Printf("Failed to create price table: %v", err)
return err
}
return nil return nil
} }

View File

@ -1,14 +1,18 @@
module aimodels-prices module aimodels-prices
go 1.21 go 1.23.0
toolchain go1.23.1
require ( require (
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-sql-driver/mysql v1.7.1
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
gorm.io/driver/mysql v1.5.7
gorm.io/gorm v1.25.12
) )
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
@ -16,7 +20,10 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.9.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect
@ -30,7 +37,7 @@ require (
golang.org/x/crypto v0.14.0 // indirect golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@ -1,3 +1,5 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@ -7,8 +9,6 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@ -23,25 +23,23 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@ -49,8 +47,6 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -60,8 +56,6 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@ -82,20 +76,15 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
@ -105,32 +94,9 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY=
modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk=
modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ=
modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM=
modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM=
modernc.org/libc v1.29.0 h1:tTFRFq69YKCF2QyGNuRUQxKBm1uZZLubf6Cjh/pVHXs=
modernc.org/libc v1.29.0/go.mod h1:DaG/4Q3LRRdqpiLyP0C2m1B8ZMGkQ+cCgOIjEtQlYhQ=
modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4=
modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo=
modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E=
modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E=
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
modernc.org/sqlite v1.28.0 h1:Zx+LyDDmXczNnEQdvPuEfcFVA2ZPyaD7UCZDjef3BHQ=
modernc.org/sqlite v1.28.0/go.mod h1:Qxpazz0zH8Z1xCFyi5GSL3FzbtZ3fvbjmywNogldEW0=
modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw=
modernc.org/tcl v1.15.2 h1:C4ybAYCGJw968e+Me18oW55kD/FexcHbqH2xak1ROSY=
modernc.org/tcl v1.15.2/go.mod h1:3+k/ZaEbKrC8ePv8zJWPtBSW0V7Gg9g8rkmhI1Kfs3c=
modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg=
modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
modernc.org/z v1.7.3 h1:zDJf6iHjrnB+WRD88stbXokugjyc0/pB91ri1gO6LZY=
modernc.org/z v1.7.3/go.mod h1:Ipv4tsdxZRbQyLq9Q1M6gdbkxYzdlrciF2Hi/lS7nWE=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@ -2,7 +2,6 @@ package handlers
import ( import (
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,9 +12,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"aimodels-prices/database"
"aimodels-prices/models" "aimodels-prices/models"
) )
// generateSessionID 生成随机会话ID
func generateSessionID() string { func generateSessionID() string {
b := make([]byte, 32) b := make([]byte, 32)
rand.Read(b) rand.Read(b)
@ -29,11 +30,8 @@ func GetAuthStatus(c *gin.Context) {
return return
} }
db := c.MustGet("db").(*sql.DB)
var session models.Session var session models.Session
err = db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan( if err := database.DB.Preload("User").Where("id = ?", cookie).First(&session).Error; err != nil {
&session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"})
return return
} }
@ -43,15 +41,9 @@ func GetAuthStatus(c *gin.Context) {
return return
} }
user, err := session.GetUser(db) c.Set("user", &session.User)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return
}
c.Set("user", user)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"user": user, "user": session.User,
}) })
} }
@ -81,8 +73,8 @@ func Login(c *gin.Context) {
func Logout(c *gin.Context) { func Logout(c *gin.Context) {
cookie, err := c.Cookie("session") cookie, err := c.Cookie("session")
if err == nil { if err == nil {
db := c.MustGet("db").(*sql.DB) // 删除会话
db.Exec("DELETE FROM session WHERE id = ?", cookie) database.DB.Where("id = ?", cookie).Delete(&models.Session{})
} }
c.SetCookie("session", "", -1, "/", "aimodels-prices.q58.club", true, true) c.SetCookie("session", "", -1, "/", "aimodels-prices.q58.club", true, true)
@ -96,22 +88,19 @@ func GetUser(c *gin.Context) {
return return
} }
db := c.MustGet("db").(*sql.DB)
var session models.Session var session models.Session
if err := db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan( if err := database.DB.Preload("User").Where("id = ?", cookie).First(&session).Error; err != nil {
&session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"})
return return
} }
user, err := session.GetUser(db) if session.ExpiresAt.Before(time.Now()) {
if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"user": user, "user": session.User,
}) })
} }
@ -200,12 +189,9 @@ func AuthCallback(c *gin.Context) {
// 添加调试日志 // 添加调试日志
fmt.Printf("收到OAuth用户信息: ID=%v, Username=%s, Email=%s\n", userInfo.ID, userInfo.Username, userInfo.Email) fmt.Printf("收到OAuth用户信息: ID=%v, Username=%s, Email=%s\n", userInfo.ID, userInfo.Username, userInfo.Email)
db := c.MustGet("db").(*sql.DB)
// 检查用户是否存在 // 检查用户是否存在
var user models.User var user models.User
err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan( result := database.DB.Where("email = ?", userInfo.Email).First(&user)
&user.ID, &user.Username, &user.Email, &user.Role)
role := "user" role := "user"
if userInfo.ID == 1 { // 这里写自己的用户ID if userInfo.ID == 1 { // 这里写自己的用户ID
@ -215,44 +201,37 @@ func AuthCallback(c *gin.Context) {
fmt.Printf("用户 %s (ID=%v) 不是管理员\n", userInfo.Username, userInfo.ID) fmt.Printf("用户 %s (ID=%v) 不是管理员\n", userInfo.Username, userInfo.ID)
} }
if err == sql.ErrNoRows { if result.Error != nil {
// 创建新用户 // 创建新用户
result, err := db.Exec(`
INSERT INTO user (username, email, role)
VALUES (?, ?, ?)`,
userInfo.Username, userInfo.Email, role)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
userID, _ := result.LastInsertId()
user = models.User{ user = models.User{
ID: uint(userID),
Username: userInfo.Username, Username: userInfo.Username,
Email: userInfo.Email, Email: userInfo.Email,
Role: role, Role: role,
} }
} else if err != nil { if err := database.DB.Create(&user).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Database error"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return return
}
} else { } else {
// 更新现有用户的角色(如果需要) // 更新现有用户的角色(如果需要)
if user.Role != role { if user.Role != role {
_, err = db.Exec("UPDATE user SET role = ? WHERE id = ?", role, user.ID) user.Role = role
if err != nil { if err := database.DB.Save(&user).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user role"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user role"})
return return
} }
user.Role = role
} }
} }
// 创建会话 // 创建会话
sessionID := generateSessionID() sessionID := generateSessionID()
expiresAt := time.Now().Add(24 * time.Hour) expiresAt := time.Now().Add(24 * time.Hour)
_, err = db.Exec("INSERT INTO session (id, user_id, expires_at) VALUES (?, ?, ?)", session := models.Session{
sessionID, user.ID, expiresAt) ID: sessionID,
if err != nil { UserID: user.ID,
ExpiresAt: expiresAt,
}
if err := database.DB.Create(&session).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
return return
} }

View File

@ -1,55 +1,37 @@
package handlers package handlers
import ( import (
"database/sql"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"aimodels-prices/database"
"aimodels-prices/models" "aimodels-prices/models"
) )
// GetModelTypes 获取所有模型类型 // GetModelTypes 获取所有模型类型
func GetModelTypes(c *gin.Context) { func GetModelTypes(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
rows, err := db.Query("SELECT type_key, type_label, sort_order FROM model_type ORDER BY sort_order ASC, type_key ASC")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer rows.Close()
var types []models.ModelType var types []models.ModelType
for rows.Next() {
var t models.ModelType // 使用GORM查询所有模型类型按排序字段和键值排序
if err := rows.Scan(&t.TypeKey, &t.TypeLabel, &t.SortOrder); err != nil { if err := database.DB.Order("sort_order ASC, type_key ASC").Find(&types).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
types = append(types, t)
}
c.JSON(http.StatusOK, types) c.JSON(http.StatusOK, types)
} }
// CreateModelType 添加新的模型类型 // CreateModelType 添加新的模型类型
func CreateModelType(c *gin.Context) { func CreateModelType(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
var newType models.ModelType var newType models.ModelType
if err := c.ShouldBindJSON(&newType); err != nil { if err := c.ShouldBindJSON(&newType); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
_, err := db.Exec(` // 使用GORM创建新记录
INSERT INTO model_type (type_key, type_label, sort_order) if err := database.DB.Create(&newType).Error; err != nil {
VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE type_label = VALUES(type_label), sort_order = VALUES(sort_order)
`, newType.TypeKey, newType.TypeLabel, newType.SortOrder)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -59,58 +41,54 @@ func CreateModelType(c *gin.Context) {
// UpdateModelType 更新模型类型 // UpdateModelType 更新模型类型
func UpdateModelType(c *gin.Context) { func UpdateModelType(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
typeKey := c.Param("key") typeKey := c.Param("key")
var updateType models.ModelType var updateType models.ModelType
if err := c.ShouldBindJSON(&updateType); err != nil { if err := c.ShouldBindJSON(&updateType); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
// 如果key发生变化需要删除旧记录并创建新记录 // 查找现有记录
if typeKey != updateType.TypeKey { var existingType models.ModelType
tx, err := db.Begin() if err := database.DB.Where("type_key = ?", typeKey).First(&existingType).Error; err != nil {
if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "Model type not found"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to begin transaction"})
return return
} }
// 如果key发生变化需要删除旧记录并创建新记录
if typeKey != updateType.TypeKey {
// 开始事务
tx := database.DB.Begin()
// 删除旧记录 // 删除旧记录
_, err = tx.Exec("DELETE FROM model_type WHERE type_key = ?", typeKey) if err := tx.Delete(&existingType).Error; err != nil {
if err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old model type"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old model type"})
return return
} }
// 创建新记录 // 创建新记录
_, err = tx.Exec(` if err := tx.Create(&updateType).Error; err != nil {
INSERT INTO model_type (type_key, type_label, sort_order)
VALUES (?, ?, ?)
`, updateType.TypeKey, updateType.TypeLabel, updateType.SortOrder)
if err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new model type"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new model type"})
return return
} }
if err := tx.Commit(); err != nil { // 提交事务
if err := tx.Commit().Error; err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
return return
} }
} else { } else {
// 直接更新 // 直接更新
_, err := db.Exec(` existingType.TypeLabel = updateType.TypeLabel
UPDATE model_type existingType.SortOrder = updateType.SortOrder
SET type_label = ?, sort_order = ? if err := database.DB.Save(&existingType).Error; err != nil {
WHERE type_key = ?
`, updateType.TypeLabel, updateType.SortOrder, typeKey)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"})
return return
} }
updateType = existingType
} }
c.JSON(http.StatusOK, updateType) c.JSON(http.StatusOK, updateType)
@ -118,13 +96,18 @@ func UpdateModelType(c *gin.Context) {
// DeleteModelType 删除模型类型 // DeleteModelType 删除模型类型
func DeleteModelType(c *gin.Context) { func DeleteModelType(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
typeKey := c.Param("key") typeKey := c.Param("key")
// 查找现有记录
var existingType models.ModelType
if err := database.DB.Where("type_key = ?", typeKey).First(&existingType).Error; err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Model type not found"})
return
}
// 检查是否有价格记录使用此类型 // 检查是否有价格记录使用此类型
var count int var count int64
err := db.QueryRow("SELECT COUNT(*) FROM price WHERE model_type = ?", typeKey).Scan(&count) if err := database.DB.Model(&models.Price{}).Where("model_type = ?", typeKey).Count(&count).Error; err != nil {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model type usage"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model type usage"})
return return
} }
@ -134,8 +117,8 @@ func DeleteModelType(c *gin.Context) {
return return
} }
_, err = db.Exec("DELETE FROM model_type WHERE type_key = ?", typeKey) // 删除记录
if err != nil { if err := database.DB.Delete(&existingType).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete model type"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete model type"})
return return
} }

View File

@ -1,39 +1,24 @@
package handlers package handlers
import ( import (
"database/sql"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"aimodels-prices/database"
"aimodels-prices/models" "aimodels-prices/models"
) )
// GetProviders 获取所有模型厂商 // GetProviders 获取所有模型厂商
func GetProviders(c *gin.Context) { func GetProviders(c *gin.Context) {
db := c.MustGet("db").(*sql.DB) var providers []models.Provider
rows, err := db.Query(`
SELECT id, name, icon, created_at, updated_at, created_by if err := database.DB.Order("id").Find(&providers).Error; err != nil {
FROM provider ORDER BY id`)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch providers"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch providers"})
return return
} }
defer rows.Close()
var providers []models.Provider
for rows.Next() {
var provider models.Provider
if err := rows.Scan(
&provider.ID, &provider.Name, &provider.Icon,
&provider.CreatedAt, &provider.UpdatedAt, &provider.CreatedBy); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan provider"})
return
}
providers = append(providers, provider)
}
c.JSON(http.StatusOK, providers) c.JSON(http.StatusOK, providers)
} }
@ -47,10 +32,9 @@ func CreateProvider(c *gin.Context) {
} }
// 检查ID是否已存在 // 检查ID是否已存在
db := c.MustGet("db").(*sql.DB) var existingProvider models.Provider
var existingID int result := database.DB.Where("id = ?", provider.ID).First(&existingProvider)
err := db.QueryRow("SELECT id FROM provider WHERE id = ?", provider.ID).Scan(&existingID) if result.Error == nil {
if err != sql.ErrNoRows {
c.JSON(http.StatusBadRequest, gin.H{"error": "ID already exists"}) c.JSON(http.StatusBadRequest, gin.H{"error": "ID already exists"})
return return
} }
@ -63,20 +47,15 @@ func CreateProvider(c *gin.Context) {
} }
currentUser := user.(*models.User) currentUser := user.(*models.User)
now := time.Now() // 设置创建者
_, err = db.Exec(` provider.CreatedBy = currentUser.Username
INSERT INTO provider (id, name, icon, created_at, updated_at, created_by)
VALUES (?, ?, ?, ?, ?, ?)`, // 创建记录
provider.ID, provider.Name, provider.Icon, now, now, currentUser.Username) if err := database.DB.Create(&provider).Error; err != nil {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"})
return return
} }
provider.CreatedAt = now
provider.UpdatedAt = now
provider.CreatedBy = currentUser.Username
c.JSON(http.StatusCreated, provider) c.JSON(http.StatusCreated, provider)
} }
@ -89,71 +68,64 @@ func UpdateProvider(c *gin.Context) {
return return
} }
db := c.MustGet("db").(*sql.DB) // 查找现有记录
var existingProvider models.Provider
// 开始事务 if err := database.DB.Where("id = ?", oldID).First(&existingProvider).Error; err != nil {
tx, err := db.Begin() c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to begin transaction"})
return return
} }
// 如果ID发生变化需要同时更新price表中的引用 // 如果ID发生变化需要同时更新price表中的引用
if oldID != strconv.FormatUint(uint64(provider.ID), 10) { if oldID != strconv.FormatUint(uint64(provider.ID), 10) {
// 开始事务
tx := database.DB.Begin()
// 更新price表中的channel_type // 更新price表中的channel_type
_, err = tx.Exec("UPDATE price SET channel_type = ? WHERE channel_type = ?", provider.ID, oldID) if err := tx.Model(&models.Price{}).Where("channel_type = ?", oldID).Update("channel_type", provider.ID).Error; err != nil {
if err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price references"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price references"})
return return
} }
// 更新price表中的temp_channel_type // 更新price表中的temp_channel_type
_, err = tx.Exec("UPDATE price SET temp_channel_type = ? WHERE temp_channel_type = ?", provider.ID, oldID) if err := tx.Model(&models.Price{}).Where("temp_channel_type = ?", oldID).Update("temp_channel_type", provider.ID).Error; err != nil {
if err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price temp references"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price temp references"})
return return
} }
// 删除旧记录 // 删除旧记录
_, err = tx.Exec("DELETE FROM provider WHERE id = ?", oldID) if err := tx.Delete(&existingProvider).Error; err != nil {
if err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old provider"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old provider"})
return return
} }
// 插入新记录 // 创建新记录
_, err = tx.Exec(` provider.CreatedAt = time.Now()
INSERT INTO provider (id, name, icon, created_at, updated_at) provider.UpdatedAt = time.Now()
VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) if err := tx.Create(&provider).Error; err != nil {
`, provider.ID, provider.Name, provider.Icon)
if err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new provider"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new provider"})
return return
} }
} else {
// 如果ID没有变化直接更新
_, err = tx.Exec(`
UPDATE provider
SET name = ?, icon = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, provider.Name, provider.Icon, oldID)
if err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
return
}
}
// 提交事务 // 提交事务
if err := tx.Commit(); err != nil { if err := tx.Commit().Error; err != nil {
tx.Rollback() tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
return return
} }
} else {
// 如果ID没有变化直接更新
existingProvider.Name = provider.Name
existingProvider.Icon = provider.Icon
if err := database.DB.Save(&existingProvider).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
return
}
provider = existingProvider
}
c.JSON(http.StatusOK, provider) c.JSON(http.StatusOK, provider)
} }
@ -170,12 +142,11 @@ func UpdateProviderStatus(c *gin.Context) {
return return
} }
db := c.MustGet("db").(*sql.DB)
now := time.Now() now := time.Now()
if input.Status == "approved" { if input.Status == "approved" {
// 如果是批准,将临时字段的值更新到正式字段 // 如果是批准,将临时字段的值更新到正式字段
_, err := db.Exec(` result := database.DB.Exec(`
UPDATE provider UPDATE provider
SET name = COALESCE(temp_name, name), SET name = COALESCE(temp_name, name),
icon = COALESCE(temp_icon, icon), icon = COALESCE(temp_icon, icon),
@ -185,13 +156,13 @@ func UpdateProviderStatus(c *gin.Context) {
temp_icon = NULL, temp_icon = NULL,
updated_by = NULL updated_by = NULL
WHERE id = ?`, input.Status, now, id) WHERE id = ?`, input.Status, now, id)
if err != nil { if result.Error != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider status"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider status"})
return return
} }
} else { } else {
// 如果是拒绝,清除临时字段 // 如果是拒绝,清除临时字段
_, err := db.Exec(` result := database.DB.Exec(`
UPDATE provider UPDATE provider
SET status = ?, SET status = ?,
updated_at = ?, updated_at = ?,
@ -199,7 +170,7 @@ func UpdateProviderStatus(c *gin.Context) {
temp_icon = NULL, temp_icon = NULL,
updated_by = NULL updated_by = NULL
WHERE id = ?`, input.Status, now, id) WHERE id = ?`, input.Status, now, id)
if err != nil { if result.Error != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider status"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider status"})
return return
} }
@ -215,9 +186,28 @@ func UpdateProviderStatus(c *gin.Context) {
// DeleteProvider 删除模型厂商 // DeleteProvider 删除模型厂商
func DeleteProvider(c *gin.Context) { func DeleteProvider(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
db := c.MustGet("db").(*sql.DB)
_, err := db.Exec("DELETE FROM provider WHERE id = ?", id) // 查找现有记录
if err != nil { var provider models.Provider
if err := database.DB.Where("id = ?", id).First(&provider).Error; err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"})
return
}
// 检查是否有价格记录使用此厂商
var count int64
if err := database.DB.Model(&models.Price{}).Where("channel_type = ?", id).Count(&count).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check provider usage"})
return
}
if count > 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot delete provider that is in use"})
return
}
// 删除记录
if err := database.DB.Delete(&provider).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"})
return return
} }

View File

@ -23,7 +23,6 @@ func main() {
if err := database.InitDB(cfg); err != nil { if err := database.InitDB(cfg); err != nil {
log.Fatalf("Failed to initialize database: %v", err) log.Fatalf("Failed to initialize database: %v", err)
} }
defer database.DB.Close()
// 设置gin模式 // 设置gin模式
if gin.Mode() == gin.ReleaseMode { if gin.Mode() == gin.ReleaseMode {
@ -32,12 +31,6 @@ func main() {
r := gin.Default() r := gin.Default()
// 注入数据库
r.Use(func(c *gin.Context) {
c.Set("db", database.DB)
c.Next()
})
// CORS中间件 // CORS中间件
r.Use(func(c *gin.Context) { r.Use(func(c *gin.Context) {
origin := c.Request.Header.Get("Origin") origin := c.Request.Header.Get("Origin")

View File

@ -1,12 +1,12 @@
package middleware package middleware
import ( import (
"database/sql"
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"aimodels-prices/database"
"aimodels-prices/models" "aimodels-prices/models"
) )
@ -19,30 +19,14 @@ func AuthRequired() gin.HandlerFunc {
return return
} }
db := c.MustGet("db").(*sql.DB)
var session models.Session var session models.Session
err = db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan( if err := database.DB.Preload("User").Where("id = ? AND expires_at > ?", cookie, time.Now()).First(&session).Error; err != nil {
&session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired session"})
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"})
c.Abort() c.Abort()
return return
} }
if session.ExpiresAt.Before(time.Now()) { c.Set("user", &session.User)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"})
c.Abort()
return
}
user, err := session.GetUser(db)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
c.Abort()
return
}
c.Set("user", user)
c.Next() c.Next()
} }
} }

View File

@ -1,14 +0,0 @@
package middleware
import (
"aimodels-prices/database"
"github.com/gin-gonic/gin"
)
func Database() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("db", database.DB)
c.Next()
}
}

View File

@ -1,18 +1,22 @@
package models package models
import (
"time"
"gorm.io/gorm"
)
// ModelType 模型类型结构 // ModelType 模型类型结构
type ModelType struct { type ModelType struct {
TypeKey string `json:"key"` TypeKey string `json:"key" gorm:"primaryKey;column:type_key"`
TypeLabel string `json:"label"` TypeLabel string `json:"label" gorm:"column:type_label;not null"`
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order" gorm:"column:sort_order;default:0"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
} }
// CreateModelTypeTableSQL 返回创建模型类型表的 SQL // TableName 指定表名
func CreateModelTypeTableSQL() string { func (ModelType) TableName() string {
return ` return "model_type"
CREATE TABLE IF NOT EXISTS model_type (
type_key VARCHAR(50) PRIMARY KEY,
type_label VARCHAR(255) NOT NULL,
sort_order INT NOT NULL DEFAULT 0
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`
} }

View File

@ -2,60 +2,38 @@ package models
import ( import (
"time" "time"
"gorm.io/gorm"
) )
type Price struct { type Price struct {
ID uint `json:"id"` ID uint `json:"id" gorm:"primaryKey"`
Model string `json:"model"` Model string `json:"model" gorm:"not null"`
ModelType string `json:"model_type"` // text2text, text2image, etc. ModelType string `json:"model_type" gorm:"not null"` // text2text, text2image, etc.
BillingType string `json:"billing_type"` // tokens or times BillingType string `json:"billing_type" gorm:"not null"` // tokens or times
ChannelType string `json:"channel_type"` ChannelType uint `json:"channel_type" gorm:"not null"`
Currency string `json:"currency"` // USD or CNY Currency string `json:"currency" gorm:"not null"` // USD or CNY
InputPrice float64 `json:"input_price"` InputPrice float64 `json:"input_price" gorm:"not null"`
OutputPrice float64 `json:"output_price"` OutputPrice float64 `json:"output_price" gorm:"not null"`
PriceSource string `json:"price_source"` PriceSource string `json:"price_source" gorm:"not null"`
Status string `json:"status"` // pending, approved, rejected Status string `json:"status" gorm:"not null;default:pending"` // pending, approved, rejected
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
CreatedBy string `json:"created_by"` CreatedBy string `json:"created_by" gorm:"not null"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 临时字段,用于存储待审核的更新 // 临时字段,用于存储待审核的更新
TempModel *string `json:"temp_model,omitempty"` TempModel *string `json:"temp_model,omitempty" gorm:"column:temp_model"`
TempModelType *string `json:"temp_model_type,omitempty"` TempModelType *string `json:"temp_model_type,omitempty" gorm:"column:temp_model_type"`
TempBillingType *string `json:"temp_billing_type,omitempty"` TempBillingType *string `json:"temp_billing_type,omitempty" gorm:"column:temp_billing_type"`
TempChannelType *string `json:"temp_channel_type,omitempty"` TempChannelType *uint `json:"temp_channel_type,omitempty" gorm:"column:temp_channel_type"`
TempCurrency *string `json:"temp_currency,omitempty"` TempCurrency *string `json:"temp_currency,omitempty" gorm:"column:temp_currency"`
TempInputPrice *float64 `json:"temp_input_price,omitempty"` TempInputPrice *float64 `json:"temp_input_price,omitempty" gorm:"column:temp_input_price"`
TempOutputPrice *float64 `json:"temp_output_price,omitempty"` TempOutputPrice *float64 `json:"temp_output_price,omitempty" gorm:"column:temp_output_price"`
TempPriceSource *string `json:"temp_price_source,omitempty"` TempPriceSource *string `json:"temp_price_source,omitempty" gorm:"column:temp_price_source"`
UpdatedBy *string `json:"updated_by,omitempty"` UpdatedBy *string `json:"updated_by,omitempty" gorm:"column:updated_by"`
} }
// CreatePriceTableSQL 返回创建价格表的 SQL // TableName 指定表名
func CreatePriceTableSQL() string { func (Price) TableName() string {
return ` return "price"
CREATE TABLE IF NOT EXISTS price (
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
model VARCHAR(255) NOT NULL,
model_type VARCHAR(50) NOT NULL,
billing_type VARCHAR(50) NOT NULL,
channel_type BIGINT UNSIGNED NOT NULL,
currency VARCHAR(10) NOT NULL,
input_price DECIMAL(10,6) NOT NULL,
output_price DECIMAL(10,6) NOT NULL,
price_source VARCHAR(255) NOT NULL,
status VARCHAR(50) NOT NULL DEFAULT 'pending',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
created_by VARCHAR(255) NOT NULL,
temp_model VARCHAR(255),
temp_model_type VARCHAR(50),
temp_billing_type VARCHAR(50),
temp_channel_type BIGINT UNSIGNED,
temp_currency VARCHAR(10),
temp_input_price DECIMAL(10,6),
temp_output_price DECIMAL(10,6),
temp_price_source VARCHAR(255),
updated_by VARCHAR(255),
FOREIGN KEY (channel_type) REFERENCES provider(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`
} }

View File

@ -1,25 +1,22 @@
package models package models
import "time" import (
"time"
"gorm.io/gorm"
)
type Provider struct { type Provider struct {
ID uint `json:"id"` ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"` Name string `json:"name" gorm:"not null"`
Icon string `json:"icon"` Icon string `json:"icon"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
CreatedBy string `json:"created_by"` CreatedBy string `json:"created_by" gorm:"not null"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
} }
// CreateProviderTableSQL 返回创建模型厂商表的 SQL // TableName 指定表名
func CreateProviderTableSQL() string { func (Provider) TableName() string {
return ` return "provider"
CREATE TABLE IF NOT EXISTS provider (
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL,
icon VARCHAR(1024),
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
created_by VARCHAR(255) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`
} }

View File

@ -1,64 +1,37 @@
package models package models
import ( import (
"database/sql"
"time" "time"
"gorm.io/gorm"
) )
type User struct { type User struct {
ID uint `json:"id"` ID uint `json:"id" gorm:"primaryKey"`
Username string `json:"username"` Username string `json:"username" gorm:"not null;unique"`
Email string `json:"email"` Email string `json:"email" gorm:"not null"`
Role string `json:"role"` // admin or user Role string `json:"role" gorm:"not null;default:user"` // admin or user
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
DeletedAt *time.Time `json:"deleted_at,omitempty"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
} }
type Session struct { type Session struct {
ID string `json:"id"` ID string `json:"id" gorm:"primaryKey"`
UserID uint `json:"user_id"` UserID uint `json:"user_id" gorm:"not null"`
ExpiresAt time.Time `json:"expires_at"` User User `json:"user" gorm:"foreignKey:UserID"`
CreatedAt time.Time `json:"created_at"` ExpiresAt time.Time `json:"expires_at" gorm:"not null"`
UpdatedAt time.Time `json:"updated_at"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
DeletedAt *time.Time `json:"deleted_at,omitempty"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
} }
// CreateUserTableSQL 返回创建用户表的 SQL // TableName 指定User表名
func CreateUserTableSQL() string { func (User) TableName() string {
return ` return "user"
CREATE TABLE IF NOT EXISTS user (
id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
username VARCHAR(255) UNIQUE NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
role VARCHAR(50) NOT NULL DEFAULT 'user',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`
} }
// CreateSessionTableSQL 返回创建会话表的 SQL // TableName 指定Session表名
func CreateSessionTableSQL() string { func (Session) TableName() string {
return ` return "session"
CREATE TABLE IF NOT EXISTS session (
id VARCHAR(255) PRIMARY KEY,
user_id BIGINT UNSIGNED NOT NULL,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL,
FOREIGN KEY (user_id) REFERENCES user(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`
}
// GetUser 获取会话关联的用户
func (s *Session) GetUser(db *sql.DB) (*User, error) {
var user User
err := db.QueryRow("SELECT id, username, email, role, created_at, updated_at, deleted_at FROM user WHERE id = ?", s.UserID).Scan(
&user.ID, &user.Username, &user.Email, &user.Role, &user.CreatedAt, &user.UpdatedAt, &user.DeletedAt)
if err != nil {
return nil, err
}
return &user, nil
} }

View File

@ -1,37 +0,0 @@
package router
import (
"github.com/gin-gonic/gin"
"aimodels-prices/handlers"
"aimodels-prices/middleware"
)
// SetupRouter 设置路由
func SetupRouter() *gin.Engine {
r := gin.Default()
// 添加数据库中间件
r.Use(middleware.Database())
// 认证相关路由
auth := r.Group("/auth")
{
auth.GET("/status", handlers.GetAuthStatus)
auth.POST("/login", handlers.Login)
auth.POST("/logout", handlers.Logout)
}
// 模型厂商相关路由
providers := r.Group("/providers")
{
providers.GET("", handlers.GetProviders)
providers.Use(middleware.RequireAuth())
providers.Use(middleware.RequireAdmin())
providers.POST("", handlers.CreateProvider)
providers.PUT("/:id", handlers.UpdateProvider)
providers.DELETE("/:id", handlers.DeleteProvider)
}
return r
}