diff --git a/backend/database/db.go b/backend/database/db.go index b30e1ea..5f79a77 100644 --- a/backend/database/db.go +++ b/backend/database/db.go @@ -1,18 +1,19 @@ package database import ( - "database/sql" "fmt" "log" - _ "github.com/go-sql-driver/mysql" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" "aimodels-prices/config" "aimodels-prices/models" ) // DB 是数据库连接的全局实例 -var DB *sql.DB +var DB *gorm.DB // InitDB 初始化数据库连接 func InitDB(cfg *config.Config) error { @@ -28,59 +29,47 @@ func InitDB(cfg *config.Config) error { ) // 连接MySQL - DB, err = sql.Open("mysql", dsn) + DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) if err != nil { return fmt.Errorf("failed to connect to MySQL: %v", err) } - // 测试连接 - if err = DB.Ping(); err != nil { - return fmt.Errorf("failed to ping MySQL: %v", err) + // 获取底层的SQL DB + sqlDB, err := DB.DB() + if err != nil { + return fmt.Errorf("failed to get underlying SQL DB: %v", err) } // 设置连接池参数 - DB.SetMaxOpenConns(10) - DB.SetMaxIdleConns(5) + sqlDB.SetMaxOpenConns(10) + sqlDB.SetMaxIdleConns(5) - // 创建表结构 - if err = createTables(); err != nil { - return fmt.Errorf("failed to create tables: %v", err) + // 自动迁移表结构 + if err = migrateModels(); err != nil { + return fmt.Errorf("failed to migrate models: %v", err) } return nil } -// createTables 创建数据库表 -func createTables() error { - // 创建用户表 - if _, err := DB.Exec(models.CreateUserTableSQL()); err != nil { - log.Printf("Failed to create user table: %v", err) +// migrateModels 自动迁移模型到数据库表 +func migrateModels() error { + // 自动迁移模型 + if err := DB.AutoMigrate( + &models.ModelType{}, + &models.Price{}, + &models.Provider{}, + &models.User{}, + &models.Session{}, + ); err != nil { + log.Printf("Failed to migrate tables: %v", err) return err } - // 创建会话表 - if _, err := DB.Exec(models.CreateSessionTableSQL()); err != nil { - log.Printf("Failed to create session table: %v", err) - return err - } - - // 创建模型厂商表 - if _, err := DB.Exec(models.CreateProviderTableSQL()); err != nil { - log.Printf("Failed to create provider table: %v", err) - return err - } - - // 创建模型类型表 - if _, err := DB.Exec(models.CreateModelTypeTableSQL()); err != nil { - log.Printf("Failed to create model_type table: %v", err) - return err - } - - // 创建价格表 - if _, err := DB.Exec(models.CreatePriceTableSQL()); err != nil { - log.Printf("Failed to create price table: %v", err) - return err - } + // 这里可以添加其他模型的迁移 + // 例如:DB.AutoMigrate(&models.User{}) return nil } diff --git a/backend/go.mod b/backend/go.mod index d5c5a0f..6be0d07 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,14 +1,18 @@ module aimodels-prices -go 1.21 +go 1.23.0 + +toolchain go1.23.1 require ( github.com/gin-gonic/gin v1.9.1 - github.com/go-sql-driver/mysql v1.7.1 github.com/joho/godotenv v1.5.1 + gorm.io/driver/mysql v1.5.7 + gorm.io/gorm v1.25.12 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -16,7 +20,10 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-sql-driver/mysql v1.9.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect @@ -30,7 +37,7 @@ require ( golang.org/x/crypto v0.14.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.23.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 3e755f3..a292d23 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -7,8 +9,6 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -23,25 +23,23 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= -github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= -github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo= +github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -49,8 +47,6 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -60,8 +56,6 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -82,20 +76,15 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= @@ -105,32 +94,9 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= -lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= -modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= -modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= -modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= -modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= -modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= -modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= -modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= -modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= -modernc.org/libc v1.29.0 h1:tTFRFq69YKCF2QyGNuRUQxKBm1uZZLubf6Cjh/pVHXs= -modernc.org/libc v1.29.0/go.mod h1:DaG/4Q3LRRdqpiLyP0C2m1B8ZMGkQ+cCgOIjEtQlYhQ= -modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= -modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= -modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= -modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= -modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.28.0 h1:Zx+LyDDmXczNnEQdvPuEfcFVA2ZPyaD7UCZDjef3BHQ= -modernc.org/sqlite v1.28.0/go.mod h1:Qxpazz0zH8Z1xCFyi5GSL3FzbtZ3fvbjmywNogldEW0= -modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= -modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= -modernc.org/tcl v1.15.2 h1:C4ybAYCGJw968e+Me18oW55kD/FexcHbqH2xak1ROSY= -modernc.org/tcl v1.15.2/go.mod h1:3+k/ZaEbKrC8ePv8zJWPtBSW0V7Gg9g8rkmhI1Kfs3c= -modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= -modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -modernc.org/z v1.7.3 h1:zDJf6iHjrnB+WRD88stbXokugjyc0/pB91ri1gO6LZY= -modernc.org/z v1.7.3/go.mod h1:Ipv4tsdxZRbQyLq9Q1M6gdbkxYzdlrciF2Hi/lS7nWE= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/backend/handlers/auth.go b/backend/handlers/auth.go index 1d8e5f6..1ab1324 100644 --- a/backend/handlers/auth.go +++ b/backend/handlers/auth.go @@ -2,7 +2,6 @@ package handlers import ( "crypto/rand" - "database/sql" "encoding/hex" "encoding/json" "fmt" @@ -13,9 +12,11 @@ import ( "github.com/gin-gonic/gin" + "aimodels-prices/database" "aimodels-prices/models" ) +// generateSessionID 生成随机会话ID func generateSessionID() string { b := make([]byte, 32) rand.Read(b) @@ -29,11 +30,8 @@ func GetAuthStatus(c *gin.Context) { return } - db := c.MustGet("db").(*sql.DB) var session models.Session - err = db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan( - &session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt) - if err != nil { + if err := database.DB.Preload("User").Where("id = ?", cookie).First(&session).Error; err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"}) return } @@ -43,15 +41,9 @@ func GetAuthStatus(c *gin.Context) { return } - user, err := session.GetUser(db) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"}) - return - } - - c.Set("user", user) + c.Set("user", &session.User) c.JSON(http.StatusOK, gin.H{ - "user": user, + "user": session.User, }) } @@ -81,8 +73,8 @@ func Login(c *gin.Context) { func Logout(c *gin.Context) { cookie, err := c.Cookie("session") if err == nil { - db := c.MustGet("db").(*sql.DB) - db.Exec("DELETE FROM session WHERE id = ?", cookie) + // 删除会话 + database.DB.Where("id = ?", cookie).Delete(&models.Session{}) } c.SetCookie("session", "", -1, "/", "aimodels-prices.q58.club", true, true) @@ -96,22 +88,19 @@ func GetUser(c *gin.Context) { return } - db := c.MustGet("db").(*sql.DB) var session models.Session - if err := db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan( - &session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt); err != nil { + if err := database.DB.Preload("User").Where("id = ?", cookie).First(&session).Error; err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"}) return } - user, err := session.GetUser(db) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"}) + if session.ExpiresAt.Before(time.Now()) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"}) return } c.JSON(http.StatusOK, gin.H{ - "user": user, + "user": session.User, }) } @@ -200,12 +189,9 @@ func AuthCallback(c *gin.Context) { // 添加调试日志 fmt.Printf("收到OAuth用户信息: ID=%v, Username=%s, Email=%s\n", userInfo.ID, userInfo.Username, userInfo.Email) - db := c.MustGet("db").(*sql.DB) - // 检查用户是否存在 var user models.User - err = db.QueryRow("SELECT id, username, email, role FROM user WHERE email = ?", userInfo.Email).Scan( - &user.ID, &user.Username, &user.Email, &user.Role) + result := database.DB.Where("email = ?", userInfo.Email).First(&user) role := "user" if userInfo.ID == 1 { // 这里写自己的用户ID @@ -215,44 +201,37 @@ func AuthCallback(c *gin.Context) { fmt.Printf("用户 %s (ID=%v) 不是管理员\n", userInfo.Username, userInfo.ID) } - if err == sql.ErrNoRows { + if result.Error != nil { // 创建新用户 - result, err := db.Exec(` - INSERT INTO user (username, email, role) - VALUES (?, ?, ?)`, - userInfo.Username, userInfo.Email, role) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) - return - } - userID, _ := result.LastInsertId() user = models.User{ - ID: uint(userID), Username: userInfo.Username, Email: userInfo.Email, Role: role, } - } else if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Database error"}) - return + if err := database.DB.Create(&user).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) + return + } } else { // 更新现有用户的角色(如果需要) if user.Role != role { - _, err = db.Exec("UPDATE user SET role = ? WHERE id = ?", role, user.ID) - if err != nil { + user.Role = role + if err := database.DB.Save(&user).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user role"}) return } - user.Role = role } } // 创建会话 sessionID := generateSessionID() expiresAt := time.Now().Add(24 * time.Hour) - _, err = db.Exec("INSERT INTO session (id, user_id, expires_at) VALUES (?, ?, ?)", - sessionID, user.ID, expiresAt) - if err != nil { + session := models.Session{ + ID: sessionID, + UserID: user.ID, + ExpiresAt: expiresAt, + } + if err := database.DB.Create(&session).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"}) return } diff --git a/backend/handlers/model_type.go b/backend/handlers/model_type.go index 9737f31..4ba1876 100644 --- a/backend/handlers/model_type.go +++ b/backend/handlers/model_type.go @@ -1,55 +1,37 @@ package handlers import ( - "database/sql" "net/http" "github.com/gin-gonic/gin" + "aimodels-prices/database" "aimodels-prices/models" ) // GetModelTypes 获取所有模型类型 func GetModelTypes(c *gin.Context) { - db := c.MustGet("db").(*sql.DB) + var types []models.ModelType - rows, err := db.Query("SELECT type_key, type_label, sort_order FROM model_type ORDER BY sort_order ASC, type_key ASC") - if err != nil { + // 使用GORM查询所有模型类型,按排序字段和键值排序 + if err := database.DB.Order("sort_order ASC, type_key ASC").Find(&types).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - defer rows.Close() - - var types []models.ModelType - for rows.Next() { - var t models.ModelType - if err := rows.Scan(&t.TypeKey, &t.TypeLabel, &t.SortOrder); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - types = append(types, t) - } c.JSON(http.StatusOK, types) } // CreateModelType 添加新的模型类型 func CreateModelType(c *gin.Context) { - db := c.MustGet("db").(*sql.DB) - var newType models.ModelType if err := c.ShouldBindJSON(&newType); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - _, err := db.Exec(` - INSERT INTO model_type (type_key, type_label, sort_order) - VALUES (?, ?, ?) - ON DUPLICATE KEY UPDATE type_label = VALUES(type_label), sort_order = VALUES(sort_order) - `, newType.TypeKey, newType.TypeLabel, newType.SortOrder) - - if err != nil { + // 使用GORM创建新记录 + if err := database.DB.Create(&newType).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -59,58 +41,54 @@ func CreateModelType(c *gin.Context) { // UpdateModelType 更新模型类型 func UpdateModelType(c *gin.Context) { - db := c.MustGet("db").(*sql.DB) typeKey := c.Param("key") - var updateType models.ModelType if err := c.ShouldBindJSON(&updateType); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + // 查找现有记录 + var existingType models.ModelType + if err := database.DB.Where("type_key = ?", typeKey).First(&existingType).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Model type not found"}) + return + } + // 如果key发生变化,需要删除旧记录并创建新记录 if typeKey != updateType.TypeKey { - tx, err := db.Begin() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to begin transaction"}) - return - } + // 开始事务 + tx := database.DB.Begin() // 删除旧记录 - _, err = tx.Exec("DELETE FROM model_type WHERE type_key = ?", typeKey) - if err != nil { + if err := tx.Delete(&existingType).Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old model type"}) return } // 创建新记录 - _, err = tx.Exec(` - INSERT INTO model_type (type_key, type_label, sort_order) - VALUES (?, ?, ?) - `, updateType.TypeKey, updateType.TypeLabel, updateType.SortOrder) - if err != nil { + if err := tx.Create(&updateType).Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new model type"}) return } - if err := tx.Commit(); err != nil { + // 提交事务 + if err := tx.Commit().Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) return } } else { // 直接更新 - _, err := db.Exec(` - UPDATE model_type - SET type_label = ?, sort_order = ? - WHERE type_key = ? - `, updateType.TypeLabel, updateType.SortOrder, typeKey) - if err != nil { + existingType.TypeLabel = updateType.TypeLabel + existingType.SortOrder = updateType.SortOrder + if err := database.DB.Save(&existingType).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"}) return } + updateType = existingType } c.JSON(http.StatusOK, updateType) @@ -118,13 +96,18 @@ func UpdateModelType(c *gin.Context) { // DeleteModelType 删除模型类型 func DeleteModelType(c *gin.Context) { - db := c.MustGet("db").(*sql.DB) typeKey := c.Param("key") + // 查找现有记录 + var existingType models.ModelType + if err := database.DB.Where("type_key = ?", typeKey).First(&existingType).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Model type not found"}) + return + } + // 检查是否有价格记录使用此类型 - var count int - err := db.QueryRow("SELECT COUNT(*) FROM price WHERE model_type = ?", typeKey).Scan(&count) - if err != nil { + var count int64 + if err := database.DB.Model(&models.Price{}).Where("model_type = ?", typeKey).Count(&count).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model type usage"}) return } @@ -134,8 +117,8 @@ func DeleteModelType(c *gin.Context) { return } - _, err = db.Exec("DELETE FROM model_type WHERE type_key = ?", typeKey) - if err != nil { + // 删除记录 + if err := database.DB.Delete(&existingType).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete model type"}) return } diff --git a/backend/handlers/providers.go b/backend/handlers/providers.go index 107e40d..9b2d420 100644 --- a/backend/handlers/providers.go +++ b/backend/handlers/providers.go @@ -1,39 +1,24 @@ package handlers import ( - "database/sql" "net/http" "strconv" "time" "github.com/gin-gonic/gin" + "aimodels-prices/database" "aimodels-prices/models" ) // GetProviders 获取所有模型厂商 func GetProviders(c *gin.Context) { - db := c.MustGet("db").(*sql.DB) - rows, err := db.Query(` - SELECT id, name, icon, created_at, updated_at, created_by - FROM provider ORDER BY id`) - if err != nil { + var providers []models.Provider + + if err := database.DB.Order("id").Find(&providers).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch providers"}) return } - defer rows.Close() - - var providers []models.Provider - for rows.Next() { - var provider models.Provider - if err := rows.Scan( - &provider.ID, &provider.Name, &provider.Icon, - &provider.CreatedAt, &provider.UpdatedAt, &provider.CreatedBy); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan provider"}) - return - } - providers = append(providers, provider) - } c.JSON(http.StatusOK, providers) } @@ -47,10 +32,9 @@ func CreateProvider(c *gin.Context) { } // 检查ID是否已存在 - db := c.MustGet("db").(*sql.DB) - var existingID int - err := db.QueryRow("SELECT id FROM provider WHERE id = ?", provider.ID).Scan(&existingID) - if err != sql.ErrNoRows { + var existingProvider models.Provider + result := database.DB.Where("id = ?", provider.ID).First(&existingProvider) + if result.Error == nil { c.JSON(http.StatusBadRequest, gin.H{"error": "ID already exists"}) return } @@ -63,20 +47,15 @@ func CreateProvider(c *gin.Context) { } currentUser := user.(*models.User) - now := time.Now() - _, err = db.Exec(` - INSERT INTO provider (id, name, icon, created_at, updated_at, created_by) - VALUES (?, ?, ?, ?, ?, ?)`, - provider.ID, provider.Name, provider.Icon, now, now, currentUser.Username) - if err != nil { + // 设置创建者 + provider.CreatedBy = currentUser.Username + + // 创建记录 + if err := database.DB.Create(&provider).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create provider"}) return } - provider.CreatedAt = now - provider.UpdatedAt = now - provider.CreatedBy = currentUser.Username - c.JSON(http.StatusCreated, provider) } @@ -89,70 +68,63 @@ func UpdateProvider(c *gin.Context) { return } - db := c.MustGet("db").(*sql.DB) - - // 开始事务 - tx, err := db.Begin() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to begin transaction"}) + // 查找现有记录 + var existingProvider models.Provider + if err := database.DB.Where("id = ?", oldID).First(&existingProvider).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"}) return } // 如果ID发生变化,需要同时更新price表中的引用 if oldID != strconv.FormatUint(uint64(provider.ID), 10) { + // 开始事务 + tx := database.DB.Begin() + // 更新price表中的channel_type - _, err = tx.Exec("UPDATE price SET channel_type = ? WHERE channel_type = ?", provider.ID, oldID) - if err != nil { + if err := tx.Model(&models.Price{}).Where("channel_type = ?", oldID).Update("channel_type", provider.ID).Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price references"}) return } // 更新price表中的temp_channel_type - _, err = tx.Exec("UPDATE price SET temp_channel_type = ? WHERE temp_channel_type = ?", provider.ID, oldID) - if err != nil { + if err := tx.Model(&models.Price{}).Where("temp_channel_type = ?", oldID).Update("temp_channel_type", provider.ID).Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price temp references"}) return } // 删除旧记录 - _, err = tx.Exec("DELETE FROM provider WHERE id = ?", oldID) - if err != nil { + if err := tx.Delete(&existingProvider).Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old provider"}) return } - // 插入新记录 - _, err = tx.Exec(` - INSERT INTO provider (id, name, icon, created_at, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - `, provider.ID, provider.Name, provider.Icon) - if err != nil { + // 创建新记录 + provider.CreatedAt = time.Now() + provider.UpdatedAt = time.Now() + if err := tx.Create(&provider).Error; err != nil { tx.Rollback() c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new provider"}) return } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) + return + } } else { // 如果ID没有变化,直接更新 - _, err = tx.Exec(` - UPDATE provider - SET name = ?, icon = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? - `, provider.Name, provider.Icon, oldID) - if err != nil { - tx.Rollback() + existingProvider.Name = provider.Name + existingProvider.Icon = provider.Icon + if err := database.DB.Save(&existingProvider).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"}) return } - } - - // 提交事务 - if err := tx.Commit(); err != nil { - tx.Rollback() - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) - return + provider = existingProvider } c.JSON(http.StatusOK, provider) @@ -170,12 +142,11 @@ func UpdateProviderStatus(c *gin.Context) { return } - db := c.MustGet("db").(*sql.DB) now := time.Now() if input.Status == "approved" { // 如果是批准,将临时字段的值更新到正式字段 - _, err := db.Exec(` + result := database.DB.Exec(` UPDATE provider SET name = COALESCE(temp_name, name), icon = COALESCE(temp_icon, icon), @@ -185,13 +156,13 @@ func UpdateProviderStatus(c *gin.Context) { temp_icon = NULL, updated_by = NULL WHERE id = ?`, input.Status, now, id) - if err != nil { + if result.Error != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider status"}) return } } else { // 如果是拒绝,清除临时字段 - _, err := db.Exec(` + result := database.DB.Exec(` UPDATE provider SET status = ?, updated_at = ?, @@ -199,7 +170,7 @@ func UpdateProviderStatus(c *gin.Context) { temp_icon = NULL, updated_by = NULL WHERE id = ?`, input.Status, now, id) - if err != nil { + if result.Error != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider status"}) return } @@ -215,9 +186,28 @@ func UpdateProviderStatus(c *gin.Context) { // DeleteProvider 删除模型厂商 func DeleteProvider(c *gin.Context) { id := c.Param("id") - db := c.MustGet("db").(*sql.DB) - _, err := db.Exec("DELETE FROM provider WHERE id = ?", id) - if err != nil { + + // 查找现有记录 + var provider models.Provider + if err := database.DB.Where("id = ?", id).First(&provider).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Provider not found"}) + return + } + + // 检查是否有价格记录使用此厂商 + var count int64 + if err := database.DB.Model(&models.Price{}).Where("channel_type = ?", id).Count(&count).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check provider usage"}) + return + } + + if count > 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot delete provider that is in use"}) + return + } + + // 删除记录 + if err := database.DB.Delete(&provider).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete provider"}) return } diff --git a/backend/main.go b/backend/main.go index 6c46d34..f4f402e 100644 --- a/backend/main.go +++ b/backend/main.go @@ -23,7 +23,6 @@ func main() { if err := database.InitDB(cfg); err != nil { log.Fatalf("Failed to initialize database: %v", err) } - defer database.DB.Close() // 设置gin模式 if gin.Mode() == gin.ReleaseMode { @@ -32,12 +31,6 @@ func main() { r := gin.Default() - // 注入数据库 - r.Use(func(c *gin.Context) { - c.Set("db", database.DB) - c.Next() - }) - // CORS中间件 r.Use(func(c *gin.Context) { origin := c.Request.Header.Get("Origin") diff --git a/backend/middleware/auth.go b/backend/middleware/auth.go index b5fc917..cb32fe4 100644 --- a/backend/middleware/auth.go +++ b/backend/middleware/auth.go @@ -1,12 +1,12 @@ package middleware import ( - "database/sql" "net/http" "time" "github.com/gin-gonic/gin" + "aimodels-prices/database" "aimodels-prices/models" ) @@ -19,30 +19,14 @@ func AuthRequired() gin.HandlerFunc { return } - db := c.MustGet("db").(*sql.DB) var session models.Session - err = db.QueryRow("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session WHERE id = ?", cookie).Scan( - &session.ID, &session.UserID, &session.ExpiresAt, &session.CreatedAt, &session.UpdatedAt, &session.DeletedAt) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid session"}) + if err := database.DB.Preload("User").Where("id = ? AND expires_at > ?", cookie, time.Now()).First(&session).Error; err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired session"}) c.Abort() return } - if session.ExpiresAt.Before(time.Now()) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"}) - c.Abort() - return - } - - user, err := session.GetUser(db) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"}) - c.Abort() - return - } - - c.Set("user", user) + c.Set("user", &session.User) c.Next() } } diff --git a/backend/middleware/db.go b/backend/middleware/db.go deleted file mode 100644 index 10fad7f..0000000 --- a/backend/middleware/db.go +++ /dev/null @@ -1,14 +0,0 @@ -package middleware - -import ( - "aimodels-prices/database" - - "github.com/gin-gonic/gin" -) - -func Database() gin.HandlerFunc { - return func(c *gin.Context) { - c.Set("db", database.DB) - c.Next() - } -} diff --git a/backend/models/model_type.go b/backend/models/model_type.go index ba9e401..2708b29 100644 --- a/backend/models/model_type.go +++ b/backend/models/model_type.go @@ -1,18 +1,22 @@ package models +import ( + "time" + + "gorm.io/gorm" +) + // ModelType 模型类型结构 type ModelType struct { - TypeKey string `json:"key"` - TypeLabel string `json:"label"` - SortOrder int `json:"sort_order"` + TypeKey string `json:"key" gorm:"primaryKey;column:type_key"` + TypeLabel string `json:"label" gorm:"column:type_label;not null"` + SortOrder int `json:"sort_order" gorm:"column:sort_order;default:0"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` } -// CreateModelTypeTableSQL 返回创建模型类型表的 SQL -func CreateModelTypeTableSQL() string { - return ` - CREATE TABLE IF NOT EXISTS model_type ( - type_key VARCHAR(50) PRIMARY KEY, - type_label VARCHAR(255) NOT NULL, - sort_order INT NOT NULL DEFAULT 0 - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` +// TableName 指定表名 +func (ModelType) TableName() string { + return "model_type" } diff --git a/backend/models/price.go b/backend/models/price.go index b10cc1b..6c6fddb 100644 --- a/backend/models/price.go +++ b/backend/models/price.go @@ -2,60 +2,38 @@ package models import ( "time" + + "gorm.io/gorm" ) type Price struct { - ID uint `json:"id"` - Model string `json:"model"` - ModelType string `json:"model_type"` // text2text, text2image, etc. - BillingType string `json:"billing_type"` // tokens or times - ChannelType string `json:"channel_type"` - Currency string `json:"currency"` // USD or CNY - InputPrice float64 `json:"input_price"` - OutputPrice float64 `json:"output_price"` - PriceSource string `json:"price_source"` - Status string `json:"status"` // pending, approved, rejected - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - CreatedBy string `json:"created_by"` + ID uint `json:"id" gorm:"primaryKey"` + Model string `json:"model" gorm:"not null"` + ModelType string `json:"model_type" gorm:"not null"` // text2text, text2image, etc. + BillingType string `json:"billing_type" gorm:"not null"` // tokens or times + ChannelType uint `json:"channel_type" gorm:"not null"` + Currency string `json:"currency" gorm:"not null"` // USD or CNY + InputPrice float64 `json:"input_price" gorm:"not null"` + OutputPrice float64 `json:"output_price" gorm:"not null"` + PriceSource string `json:"price_source" gorm:"not null"` + Status string `json:"status" gorm:"not null;default:pending"` // pending, approved, rejected + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + CreatedBy string `json:"created_by" gorm:"not null"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` // 临时字段,用于存储待审核的更新 - TempModel *string `json:"temp_model,omitempty"` - TempModelType *string `json:"temp_model_type,omitempty"` - TempBillingType *string `json:"temp_billing_type,omitempty"` - TempChannelType *string `json:"temp_channel_type,omitempty"` - TempCurrency *string `json:"temp_currency,omitempty"` - TempInputPrice *float64 `json:"temp_input_price,omitempty"` - TempOutputPrice *float64 `json:"temp_output_price,omitempty"` - TempPriceSource *string `json:"temp_price_source,omitempty"` - UpdatedBy *string `json:"updated_by,omitempty"` + TempModel *string `json:"temp_model,omitempty" gorm:"column:temp_model"` + TempModelType *string `json:"temp_model_type,omitempty" gorm:"column:temp_model_type"` + TempBillingType *string `json:"temp_billing_type,omitempty" gorm:"column:temp_billing_type"` + TempChannelType *uint `json:"temp_channel_type,omitempty" gorm:"column:temp_channel_type"` + TempCurrency *string `json:"temp_currency,omitempty" gorm:"column:temp_currency"` + TempInputPrice *float64 `json:"temp_input_price,omitempty" gorm:"column:temp_input_price"` + TempOutputPrice *float64 `json:"temp_output_price,omitempty" gorm:"column:temp_output_price"` + TempPriceSource *string `json:"temp_price_source,omitempty" gorm:"column:temp_price_source"` + UpdatedBy *string `json:"updated_by,omitempty" gorm:"column:updated_by"` } -// CreatePriceTableSQL 返回创建价格表的 SQL -func CreatePriceTableSQL() string { - return ` - CREATE TABLE IF NOT EXISTS price ( - id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, - model VARCHAR(255) NOT NULL, - model_type VARCHAR(50) NOT NULL, - billing_type VARCHAR(50) NOT NULL, - channel_type BIGINT UNSIGNED NOT NULL, - currency VARCHAR(10) NOT NULL, - input_price DECIMAL(10,6) NOT NULL, - output_price DECIMAL(10,6) NOT NULL, - price_source VARCHAR(255) NOT NULL, - status VARCHAR(50) NOT NULL DEFAULT 'pending', - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - created_by VARCHAR(255) NOT NULL, - temp_model VARCHAR(255), - temp_model_type VARCHAR(50), - temp_billing_type VARCHAR(50), - temp_channel_type BIGINT UNSIGNED, - temp_currency VARCHAR(10), - temp_input_price DECIMAL(10,6), - temp_output_price DECIMAL(10,6), - temp_price_source VARCHAR(255), - updated_by VARCHAR(255), - FOREIGN KEY (channel_type) REFERENCES provider(id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` +// TableName 指定表名 +func (Price) TableName() string { + return "price" } diff --git a/backend/models/provider.go b/backend/models/provider.go index 83fd494..6925646 100644 --- a/backend/models/provider.go +++ b/backend/models/provider.go @@ -1,25 +1,22 @@ package models -import "time" +import ( + "time" + + "gorm.io/gorm" +) type Provider struct { - ID uint `json:"id"` - Name string `json:"name"` - Icon string `json:"icon"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - CreatedBy string `json:"created_by"` + ID uint `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"not null"` + Icon string `json:"icon"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + CreatedBy string `json:"created_by" gorm:"not null"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` } -// CreateProviderTableSQL 返回创建模型厂商表的 SQL -func CreateProviderTableSQL() string { - return ` - CREATE TABLE IF NOT EXISTS provider ( - id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, - name VARCHAR(255) NOT NULL, - icon VARCHAR(1024), - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - created_by VARCHAR(255) NOT NULL - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` +// TableName 指定表名 +func (Provider) TableName() string { + return "provider" } diff --git a/backend/models/user.go b/backend/models/user.go index 495ce64..e6f42ae 100644 --- a/backend/models/user.go +++ b/backend/models/user.go @@ -1,64 +1,37 @@ package models import ( - "database/sql" "time" + + "gorm.io/gorm" ) type User struct { - ID uint `json:"id"` - Username string `json:"username"` - Email string `json:"email"` - Role string `json:"role"` // admin or user - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - DeletedAt *time.Time `json:"deleted_at,omitempty"` + ID uint `json:"id" gorm:"primaryKey"` + Username string `json:"username" gorm:"not null;unique"` + Email string `json:"email" gorm:"not null"` + Role string `json:"role" gorm:"not null;default:user"` // admin or user + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` } type Session struct { - ID string `json:"id"` - UserID uint `json:"user_id"` - ExpiresAt time.Time `json:"expires_at"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - DeletedAt *time.Time `json:"deleted_at,omitempty"` + ID string `json:"id" gorm:"primaryKey"` + UserID uint `json:"user_id" gorm:"not null"` + User User `json:"user" gorm:"foreignKey:UserID"` + ExpiresAt time.Time `json:"expires_at" gorm:"not null"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` } -// CreateUserTableSQL 返回创建用户表的 SQL -func CreateUserTableSQL() string { - return ` - CREATE TABLE IF NOT EXISTS user ( - id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, - username VARCHAR(255) UNIQUE NOT NULL, - email VARCHAR(255) UNIQUE NOT NULL, - role VARCHAR(50) NOT NULL DEFAULT 'user', - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - deleted_at TIMESTAMP NULL - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` +// TableName 指定User表名 +func (User) TableName() string { + return "user" } -// CreateSessionTableSQL 返回创建会话表的 SQL -func CreateSessionTableSQL() string { - return ` - CREATE TABLE IF NOT EXISTS session ( - id VARCHAR(255) PRIMARY KEY, - user_id BIGINT UNSIGNED NOT NULL, - expires_at TIMESTAMP NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - deleted_at TIMESTAMP NULL, - FOREIGN KEY (user_id) REFERENCES user(id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` -} - -// GetUser 获取会话关联的用户 -func (s *Session) GetUser(db *sql.DB) (*User, error) { - var user User - err := db.QueryRow("SELECT id, username, email, role, created_at, updated_at, deleted_at FROM user WHERE id = ?", s.UserID).Scan( - &user.ID, &user.Username, &user.Email, &user.Role, &user.CreatedAt, &user.UpdatedAt, &user.DeletedAt) - if err != nil { - return nil, err - } - return &user, nil +// TableName 指定Session表名 +func (Session) TableName() string { + return "session" } diff --git a/backend/router/router.go b/backend/router/router.go deleted file mode 100644 index 791be3a..0000000 --- a/backend/router/router.go +++ /dev/null @@ -1,37 +0,0 @@ -package router - -import ( - "github.com/gin-gonic/gin" - - "aimodels-prices/handlers" - "aimodels-prices/middleware" -) - -// SetupRouter 设置路由 -func SetupRouter() *gin.Engine { - r := gin.Default() - - // 添加数据库中间件 - r.Use(middleware.Database()) - - // 认证相关路由 - auth := r.Group("/auth") - { - auth.GET("/status", handlers.GetAuthStatus) - auth.POST("/login", handlers.Login) - auth.POST("/logout", handlers.Logout) - } - - // 模型厂商相关路由 - providers := r.Group("/providers") - { - providers.GET("", handlers.GetProviders) - providers.Use(middleware.RequireAuth()) - providers.Use(middleware.RequireAdmin()) - providers.POST("", handlers.CreateProvider) - providers.PUT("/:id", handlers.UpdateProvider) - providers.DELETE("/:id", handlers.DeleteProvider) - } - - return r -}