diff --git a/Dockerfile b/Dockerfile index 221f6a7..631a15c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,16 +24,6 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \ rm main-* && \ chmod +x main -# 复制迁移工具 -COPY backend/migrate-* ./ -RUN if [ "$(uname -m)" = "aarch64" ]; then \ - cp migrate-arm64 migrate; \ - else \ - cp migrate-amd64 migrate; \ - fi && \ - rm migrate-* && \ - chmod +x migrate - COPY frontend/dist /app/frontend COPY backend/config/nginx.conf /etc/nginx/nginx.conf COPY scripts/start.sh ./ diff --git a/backend/cmd/migrate/main.go b/backend/cmd/migrate/main.go deleted file mode 100644 index 60639d1..0000000 --- a/backend/cmd/migrate/main.go +++ /dev/null @@ -1,238 +0,0 @@ -package main - -import ( - "database/sql" - "fmt" - "log" - "os" - - "aimodels-prices/config" - "aimodels-prices/database" -) - -func main() { - // 加载配置 - cfg, err := config.LoadConfig() - if err != nil { - log.Fatalf("Failed to load config: %v", err) - } - - // 检查SQLite数据库文件是否存在 - if _, err := os.Stat(cfg.SQLitePath); os.IsNotExist(err) { - log.Printf("SQLite database file not found at %s, skipping migration", cfg.SQLitePath) - os.Exit(0) - } - - // 连接SQLite数据库 - sqliteDB, err := database.InitSQLiteDB(cfg.SQLitePath) - if err != nil { - log.Fatalf("Failed to connect to SQLite database: %v", err) - } - defer sqliteDB.Close() - - // 初始化MySQL数据库 - if err := database.InitDB(cfg); err != nil { - log.Fatalf("Failed to initialize MySQL database: %v", err) - } - defer database.DB.Close() - - // 开始迁移数据 - if err := migrateData(sqliteDB, database.DB); err != nil { - log.Fatalf("Failed to migrate data: %v", err) - } - - log.Println("Data migration completed successfully!") -} - -func migrateData(sqliteDB *sql.DB, mysqlDB *sql.DB) error { - // 迁移用户数据 - if err := migrateUsers(sqliteDB, mysqlDB); err != nil { - return fmt.Errorf("failed to migrate users: %v", err) - } - - // 迁移会话数据 - if err := migrateSessions(sqliteDB, mysqlDB); err != nil { - return fmt.Errorf("failed to migrate sessions: %v", err) - } - - // 迁移提供商数据 - if err := migrateProviders(sqliteDB, mysqlDB); err != nil { - return fmt.Errorf("failed to migrate providers: %v", err) - } - - // 迁移价格数据 - if err := migratePrices(sqliteDB, mysqlDB); err != nil { - return fmt.Errorf("failed to migrate prices: %v", err) - } - - return nil -} - -func migrateUsers(sqliteDB *sql.DB, mysqlDB *sql.DB) error { - rows, err := sqliteDB.Query("SELECT id, username, email, role, created_at, updated_at, deleted_at FROM user") - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var ( - id uint - username string - email string - role string - createdAt string - updatedAt string - deletedAt sql.NullString - ) - - if err := rows.Scan(&id, &username, &email, &role, &createdAt, &updatedAt, &deletedAt); err != nil { - return err - } - - _, err = mysqlDB.Exec( - "INSERT INTO user (id, username, email, role, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, username, email, role, createdAt, updatedAt, deletedAt.String, - ) - if err != nil { - return err - } - } - - return rows.Err() -} - -func migrateSessions(sqliteDB *sql.DB, mysqlDB *sql.DB) error { - rows, err := sqliteDB.Query("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session") - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var ( - id string - userID uint - expiresAt string - createdAt string - updatedAt string - deletedAt sql.NullString - ) - - if err := rows.Scan(&id, &userID, &expiresAt, &createdAt, &updatedAt, &deletedAt); err != nil { - return err - } - - _, err = mysqlDB.Exec( - "INSERT INTO session (id, user_id, expires_at, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?)", - id, userID, expiresAt, createdAt, updatedAt, deletedAt.String, - ) - if err != nil { - return err - } - } - - return rows.Err() -} - -func migrateProviders(sqliteDB *sql.DB, mysqlDB *sql.DB) error { - rows, err := sqliteDB.Query("SELECT id, name, icon, created_at, updated_at, created_by FROM provider") - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var ( - id uint - name string - icon sql.NullString - createdAt string - updatedAt string - createdBy string - ) - - if err := rows.Scan(&id, &name, &icon, &createdAt, &updatedAt, &createdBy); err != nil { - return err - } - - _, err = mysqlDB.Exec( - "INSERT INTO provider (id, name, icon, created_at, updated_at, created_by) VALUES (?, ?, ?, ?, ?, ?)", - id, name, icon.String, createdAt, updatedAt, createdBy, - ) - if err != nil { - return err - } - } - - return rows.Err() -} - -func migratePrices(sqliteDB *sql.DB, mysqlDB *sql.DB) error { - rows, err := sqliteDB.Query(` - SELECT id, model, model_type, billing_type, channel_type, currency, - input_price, output_price, price_source, status, created_at, updated_at, - created_by, temp_model, temp_model_type, temp_billing_type, temp_channel_type, - temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by - FROM price - `) - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var ( - id uint - model string - modelType string - billingType string - channelType uint - currency string - inputPrice float64 - outputPrice float64 - priceSource string - status string - createdAt string - updatedAt string - createdBy string - tempModel sql.NullString - tempModelType sql.NullString - tempBillingType sql.NullString - tempChannelType sql.NullInt64 - tempCurrency sql.NullString - tempInputPrice sql.NullFloat64 - tempOutputPrice sql.NullFloat64 - tempPriceSource sql.NullString - updatedBy sql.NullString - ) - - if err := rows.Scan( - &id, &model, &modelType, &billingType, &channelType, ¤cy, - &inputPrice, &outputPrice, &priceSource, &status, &createdAt, &updatedAt, - &createdBy, &tempModel, &tempModelType, &tempBillingType, &tempChannelType, - &tempCurrency, &tempInputPrice, &tempOutputPrice, &tempPriceSource, &updatedBy, - ); err != nil { - return err - } - - _, err = mysqlDB.Exec(` - INSERT INTO price ( - id, model, model_type, billing_type, channel_type, currency, - input_price, output_price, price_source, status, created_at, updated_at, - created_by, temp_model, temp_model_type, temp_billing_type, temp_channel_type, - temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, - id, model, modelType, billingType, channelType, currency, - inputPrice, outputPrice, priceSource, status, createdAt, updatedAt, - createdBy, tempModel.String, tempModelType.String, tempBillingType.String, tempChannelType.Int64, - tempCurrency.String, tempInputPrice.Float64, tempOutputPrice.Float64, tempPriceSource.String, updatedBy.String, - ) - if err != nil { - return err - } - } - - return rows.Err() -} diff --git a/backend/config/config.go b/backend/config/config.go index f2e6d6f..30134c9 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -18,9 +18,6 @@ type Config struct { // 其他配置 ServerPort string - - // SQLite配置(用于数据迁移) - SQLitePath string } func LoadConfig() (*Config, error) { @@ -50,9 +47,6 @@ func LoadConfig() (*Config, error) { // 其他配置 ServerPort: getEnv("PORT", "8080"), - - // SQLite路径(用于数据迁移) - SQLitePath: filepath.Join(dbDir, "aimodels.db"), } return config, nil diff --git a/backend/database/db.go b/backend/database/db.go index 7f24500..b30e1ea 100644 --- a/backend/database/db.go +++ b/backend/database/db.go @@ -6,7 +6,6 @@ import ( "log" _ "github.com/go-sql-driver/mysql" - _ "modernc.org/sqlite" "aimodels-prices/config" "aimodels-prices/models" @@ -51,20 +50,6 @@ func InitDB(cfg *config.Config) error { return nil } -// InitSQLiteDB 初始化SQLite数据库连接(用于数据迁移) -func InitSQLiteDB(dbPath string) (*sql.DB, error) { - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return nil, fmt.Errorf("failed to connect to SQLite: %v", err) - } - - if err = db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping SQLite: %v", err) - } - - return db, nil -} - // createTables 创建数据库表 func createTables() error { // 创建用户表 @@ -85,6 +70,12 @@ func createTables() error { return err } + // 创建模型类型表 + if _, err := DB.Exec(models.CreateModelTypeTableSQL()); err != nil { + log.Printf("Failed to create model_type table: %v", err) + return err + } + // 创建价格表 if _, err := DB.Exec(models.CreatePriceTableSQL()); err != nil { log.Printf("Failed to create price table: %v", err) diff --git a/backend/go.mod b/backend/go.mod index a40d875..d5c5a0f 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -6,47 +6,31 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/go-sql-driver/mysql v1.7.1 github.com/joho/godotenv v1.5.1 - modernc.org/sqlite v1.28.0 ) require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.14.0 // indirect - golang.org/x/mod v0.8.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.6.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - lukechampine.com/uint128 v1.2.0 // indirect - modernc.org/cc/v3 v3.40.0 // indirect - modernc.org/ccgo/v3 v3.16.13 // indirect - modernc.org/libc v1.29.0 // indirect - modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.7.2 // indirect - modernc.org/opt v0.1.3 // indirect - modernc.org/strutil v1.1.3 // indirect - modernc.org/token v1.0.1 // indirect ) diff --git a/backend/models/model_type.go b/backend/models/model_type.go new file mode 100644 index 0000000..967119c --- /dev/null +++ b/backend/models/model_type.go @@ -0,0 +1,16 @@ +package models + +// ModelType 模型类型结构 +type ModelType struct { + TypeKey string `json:"key"` + TypeLabel string `json:"label"` +} + +// CreateModelTypeTableSQL 返回创建模型类型表的 SQL +func CreateModelTypeTableSQL() string { + return ` + CREATE TABLE IF NOT EXISTS model_type ( + type_key VARCHAR(50) PRIMARY KEY, + type_label VARCHAR(255) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` +} diff --git a/backend/scripts/init.sql b/backend/scripts/init.sql new file mode 100644 index 0000000..a3e9855 --- /dev/null +++ b/backend/scripts/init.sql @@ -0,0 +1,10 @@ +-- 初始化模型类型数据 +INSERT INTO model_type (type_key, type_label) VALUES +('text2text', '文生文'), +('text2image', '文生图'), +('text2speech', '文生音'), +('speech2text', '音生文'), +('image2text', '图生文'), +('embedding', '向量'), +('other', '其他') +ON DUPLICATE KEY UPDATE type_label = VALUES(type_label); \ No newline at end of file diff --git a/backend/scripts/migrate.go b/backend/scripts/migrate.go deleted file mode 100644 index a8ebf05..0000000 --- a/backend/scripts/migrate.go +++ /dev/null @@ -1,174 +0,0 @@ -package main - -import ( - "database/sql" - "log" - "os" - "path/filepath" - "strings" - - _ "modernc.org/sqlite" -) - -// ModelType 模型类型结构 -type ModelType struct { - Key string `json:"key"` - Label string `json:"label"` -} - -func main() { - // 确保数据目录存在 - dbDir := "./data" - if err := os.MkdirAll(dbDir, 0755); err != nil { - log.Fatalf("创建数据目录失败: %v", err) - } - - // 连接数据库 - dbPath := filepath.Join(dbDir, "aimodels.db") - db, err := sql.Open("sqlite", dbPath) - if err != nil { - log.Fatalf("连接数据库失败: %v", err) - } - defer db.Close() - - // 创建model_type表 - _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS model_type ( - key TEXT PRIMARY KEY, - label TEXT NOT NULL - ) - `) - if err != nil { - log.Fatalf("创建model_type表失败: %v", err) - } - - // 初始化默认的模型类型 - defaultTypes := []ModelType{ - {Key: "text2text", Label: "文生文"}, - {Key: "text2image", Label: "文生图"}, - {Key: "text2speech", Label: "文生音"}, - {Key: "speech2text", Label: "音生文"}, - {Key: "image2text", Label: "图生文"}, - {Key: "embedding", Label: "向量"}, - {Key: "other", Label: "其他"}, - } - - // 插入默认类型 - for _, t := range defaultTypes { - _, err = db.Exec(` - INSERT OR IGNORE INTO model_type (key, label) - VALUES (?, ?) - `, t.Key, t.Label) - if err != nil { - log.Printf("插入默认类型失败 %s: %v", t.Key, err) - } - } - - // 检查model_type列是否存在 - var hasModelType bool - err = db.QueryRow(` - SELECT COUNT(*) > 0 - FROM pragma_table_info('price') - WHERE name = 'model_type' - `).Scan(&hasModelType) - if err != nil { - log.Fatalf("检查model_type列失败: %v", err) - } - - // 如果model_type列不存在,则添加它 - if !hasModelType { - log.Println("开始添加model_type列...") - - // 开始事务 - tx, err := db.Begin() - if err != nil { - log.Fatalf("开始事务失败: %v", err) - } - - // 添加model_type列 - _, err = tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`) - if err != nil { - tx.Rollback() - log.Fatalf("添加model_type列失败: %v", err) - } - - // 添加temp_model_type列 - _, err = tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`) - if err != nil { - tx.Rollback() - log.Fatalf("添加temp_model_type列失败: %v", err) - } - - // 根据模型名称推断类型并更新 - rows, err := tx.Query(`SELECT id, model FROM price`) - if err != nil { - tx.Rollback() - log.Fatalf("查询价格数据失败: %v", err) - } - defer rows.Close() - - for rows.Next() { - var id int - var model string - if err := rows.Scan(&id, &model); err != nil { - tx.Rollback() - log.Fatalf("读取行数据失败: %v", err) - } - - // 根据模型名称推断类型 - modelType := inferModelType(model) - - // 更新model_type - _, err = tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, modelType, id) - if err != nil { - tx.Rollback() - log.Fatalf("更新model_type失败: %v", err) - } - } - - // 提交事务 - if err := tx.Commit(); err != nil { - log.Fatalf("提交事务失败: %v", err) - } - - log.Println("成功添加并更新model_type列") - } else { - log.Println("model_type列已存在,无需迁移") - } -} - -// inferModelType 根据模型名称推断模型类型 -func inferModelType(model string) string { - model = strings.ToLower(model) - - switch { - case strings.Contains(model, "gpt") || - strings.Contains(model, "llama") || - strings.Contains(model, "claude") || - strings.Contains(model, "palm") || - strings.Contains(model, "gemini") || - strings.Contains(model, "qwen") || - strings.Contains(model, "chatglm"): - return "text2text" - - case strings.Contains(model, "dall-e") || - strings.Contains(model, "stable") || - strings.Contains(model, "midjourney") || - strings.Contains(model, "sd") || - strings.Contains(model, "diffusion"): - return "text2image" - - case strings.Contains(model, "whisper") || - strings.Contains(model, "speech") || - strings.Contains(model, "tts"): - return "text2speech" - - case strings.Contains(model, "embedding") || - strings.Contains(model, "ada") || - strings.Contains(model, "text-embedding"): - return "embedding" - - default: - return "other" - } -} diff --git a/scripts/start.sh b/scripts/start.sh index f0db2a6..0a9e0bc 100644 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -1,9 +1,5 @@ #!/bin/bash -# 执行数据库迁移 -echo "执行数据库迁移..." -./migrate - # 启动后端服务 echo "启动后端服务..." ./main &