From 6fa37f6d6a90d871ef261cd1f13c81e3ab6e638c Mon Sep 17 00:00:00 2001 From: wood chen Date: Fri, 21 Feb 2025 11:51:53 +0800 Subject: [PATCH] Add model type support for pricing management --- .github/workflows/docker-build.yml | 8 +- .gitignore | 1 + Dockerfile | 10 ++ backend/handlers/model_type.go | 61 ++++++++++ backend/handlers/prices.go | 33 +++--- backend/main.go | 7 ++ backend/models/price.go | 4 + backend/scripts/migrate.go | 174 +++++++++++++++++++++++++++++ frontend/src/views/Home.vue | 2 +- frontend/src/views/Prices.vue | 108 ++++++++++++++++-- prices.json | 1 - scripts/start.sh | 4 + 12 files changed, 384 insertions(+), 29 deletions(-) create mode 100644 backend/handlers/model_type.go create mode 100644 backend/scripts/migrate.go delete mode 100644 prices.json diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 667eb22..e0e90db 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -21,9 +21,9 @@ jobs: # 设置 Go 环境 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.23' # 构建后端(使用 Alpine 环境) - name: Build backend @@ -31,12 +31,14 @@ jobs: cd backend GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o main-amd64 . GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o main-arm64 . + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o migrate-amd64 scripts/migrate.go + GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o migrate-arm64 scripts/migrate.go # 设置 Node.js 环境 - name: Set up Node.js uses: actions/setup-node@v4 with: - node-version: '18' + node-version: '22' cache: 'npm' cache-dependency-path: frontend/package-lock.json diff --git a/.gitignore b/.gitignore index d9ee314..d2e65bd 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ out/ # 日志文件 *.log logs/ +backend/data/aimodels.db diff --git a/Dockerfile b/Dockerfile index 631a15c..221f6a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,6 +24,16 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \ rm main-* && \ chmod +x main +# 复制迁移工具 +COPY backend/migrate-* ./ +RUN if [ "$(uname -m)" = "aarch64" ]; then \ + cp migrate-arm64 migrate; \ + else \ + cp migrate-amd64 migrate; \ + fi && \ + rm migrate-* && \ + chmod +x migrate + COPY frontend/dist /app/frontend COPY backend/config/nginx.conf /etc/nginx/nginx.conf COPY scripts/start.sh ./ diff --git a/backend/handlers/model_type.go b/backend/handlers/model_type.go new file mode 100644 index 0000000..511cdef --- /dev/null +++ b/backend/handlers/model_type.go @@ -0,0 +1,61 @@ +package handlers + +import ( + "database/sql" + + "github.com/gin-gonic/gin" +) + +// ModelType 模型类型结构 +type ModelType struct { + Key string `json:"key"` + Label string `json:"label"` +} + +// GetModelTypes 获取所有模型类型 +func GetModelTypes(c *gin.Context) { + db := c.MustGet("db").(*sql.DB) + + rows, err := db.Query("SELECT key, label FROM model_type") + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + defer rows.Close() + + var types []ModelType + for rows.Next() { + var t ModelType + if err := rows.Scan(&t.Key, &t.Label); err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + types = append(types, t) + } + + c.JSON(200, types) +} + +// CreateModelType 添加新的模型类型 +func CreateModelType(c *gin.Context) { + db := c.MustGet("db").(*sql.DB) + + var newType ModelType + if err := c.ShouldBindJSON(&newType); err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + + _, err := db.Exec(` + INSERT INTO model_type (key, label) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET label = excluded.label + `, newType.Key, newType.Label) + + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + + c.JSON(201, newType) +} diff --git a/backend/handlers/prices.go b/backend/handlers/prices.go index f54b927..57dc73a 100644 --- a/backend/handlers/prices.go +++ b/backend/handlers/prices.go @@ -50,9 +50,9 @@ func GetPrices(c *gin.Context) { // 使用分页查询 query := ` - SELECT id, model, billing_type, channel_type, currency, input_price, output_price, + SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price, price_source, status, created_at, updated_at, created_by, - temp_model, temp_billing_type, temp_channel_type, temp_currency, + temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by FROM price` if whereClause != "" { @@ -72,10 +72,10 @@ func GetPrices(c *gin.Context) { for rows.Next() { var price models.Price if err := rows.Scan( - &price.ID, &price.Model, &price.BillingType, &price.ChannelType, &price.Currency, + &price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency, &price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status, &price.CreatedAt, &price.UpdatedAt, &price.CreatedBy, - &price.TempModel, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency, + &price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency, &price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"}) return @@ -107,10 +107,10 @@ func CreatePrice(c *gin.Context) { now := time.Now() result, err := db.Exec(` - INSERT INTO price (model, billing_type, channel_type, currency, input_price, output_price, + INSERT INTO price (model, model_type, billing_type, channel_type, currency, input_price, output_price, price_source, status, created_by, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`, - price.Model, price.BillingType, price.ChannelType, price.Currency, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`, + price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency, price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy, now, now) if err != nil { @@ -146,6 +146,7 @@ func UpdatePriceStatus(c *gin.Context) { _, err := db.Exec(` UPDATE price SET model = COALESCE(temp_model, model), + model_type = COALESCE(temp_model_type, model_type), billing_type = COALESCE(temp_billing_type, billing_type), channel_type = COALESCE(temp_channel_type, channel_type), currency = COALESCE(temp_currency, currency), @@ -155,6 +156,7 @@ func UpdatePriceStatus(c *gin.Context) { status = ?, updated_at = ?, temp_model = NULL, + temp_model_type = NULL, temp_billing_type = NULL, temp_channel_type = NULL, temp_currency = NULL, @@ -174,6 +176,7 @@ func UpdatePriceStatus(c *gin.Context) { SET status = ?, updated_at = ?, temp_model = NULL, + temp_model_type = NULL, temp_billing_type = NULL, temp_channel_type = NULL, temp_currency = NULL, @@ -225,11 +228,11 @@ func UpdatePrice(c *gin.Context) { // 将新的价格信息存储到临时字段 _, err = db.Exec(` UPDATE price - SET temp_model = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?, + SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?, temp_input_price = ?, temp_output_price = ?, temp_price_source = ?, updated_by = ?, updated_at = ?, status = 'pending' WHERE id = ?`, - price.Model, price.BillingType, price.ChannelType, price.Currency, + price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency, price.InputPrice, price.OutputPrice, price.PriceSource, currentUser.Username, now, id) if err != nil { @@ -239,15 +242,15 @@ func UpdatePrice(c *gin.Context) { // 获取更新后的价格信息 err = db.QueryRow(` - SELECT id, model, billing_type, channel_type, currency, input_price, output_price, + SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price, price_source, status, created_at, updated_at, created_by, - temp_model, temp_billing_type, temp_channel_type, temp_currency, + temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by FROM price WHERE id = ?`, id).Scan( - &price.ID, &price.Model, &price.BillingType, &price.ChannelType, &price.Currency, + &price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency, &price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status, &price.CreatedAt, &price.UpdatedAt, &price.CreatedBy, - &price.TempModel, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency, + &price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency, &price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"}) @@ -274,6 +277,7 @@ func DeletePrice(c *gin.Context) { // PriceRate 价格倍率结构 type PriceRate struct { Model string `json:"model"` + ModelType string `json:"model_type"` Type string `json:"type"` ChannelType uint `json:"channel_type"` Input float64 `json:"input"` @@ -284,7 +288,7 @@ type PriceRate struct { func GetPriceRates(c *gin.Context) { db := c.MustGet("db").(*sql.DB) rows, err := db.Query(` - SELECT model, billing_type, channel_type, + SELECT model, model_type, billing_type, channel_type, CASE WHEN currency = 'USD' THEN input_price / 2 ELSE input_price / 14 @@ -307,6 +311,7 @@ func GetPriceRates(c *gin.Context) { var rate PriceRate if err := rows.Scan( &rate.Model, + &rate.ModelType, &rate.Type, &rate.ChannelType, &rate.Input, diff --git a/backend/main.go b/backend/main.go index b30509c..e2b5fee 100644 --- a/backend/main.go +++ b/backend/main.go @@ -88,6 +88,13 @@ func main() { auth.GET("/user", handlers.GetUser) auth.GET("/callback", handlers.AuthCallback) } + + // 模型类型相关路由 + modelTypes := api.Group("/model-types") + { + modelTypes.GET("", handlers.GetModelTypes) + modelTypes.POST("", middleware.AuthRequired(), handlers.CreateModelType) + } } // 启动服务器 diff --git a/backend/models/price.go b/backend/models/price.go index 8e46e61..0c7ca05 100644 --- a/backend/models/price.go +++ b/backend/models/price.go @@ -7,6 +7,7 @@ import ( type Price struct { ID uint `json:"id"` Model string `json:"model"` + ModelType string `json:"model_type"` // text2text, text2image, etc. BillingType string `json:"billing_type"` // tokens or times ChannelType string `json:"channel_type"` Currency string `json:"currency"` // USD or CNY @@ -19,6 +20,7 @@ type Price struct { CreatedBy string `json:"created_by"` // 临时字段,用于存储待审核的更新 TempModel *string `json:"temp_model,omitempty"` + TempModelType *string `json:"temp_model_type,omitempty"` TempBillingType *string `json:"temp_billing_type,omitempty"` TempChannelType *string `json:"temp_channel_type,omitempty"` TempCurrency *string `json:"temp_currency,omitempty"` @@ -34,6 +36,7 @@ func CreatePriceTableSQL() string { CREATE TABLE IF NOT EXISTS price ( id INTEGER PRIMARY KEY AUTOINCREMENT, model TEXT NOT NULL, + model_type TEXT NOT NULL, billing_type TEXT NOT NULL, channel_type TEXT NOT NULL, currency TEXT NOT NULL, @@ -45,6 +48,7 @@ func CreatePriceTableSQL() string { updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, created_by TEXT NOT NULL, temp_model TEXT, + temp_model_type TEXT, temp_billing_type TEXT, temp_channel_type TEXT, temp_currency TEXT, diff --git a/backend/scripts/migrate.go b/backend/scripts/migrate.go new file mode 100644 index 0000000..a8ebf05 --- /dev/null +++ b/backend/scripts/migrate.go @@ -0,0 +1,174 @@ +package main + +import ( + "database/sql" + "log" + "os" + "path/filepath" + "strings" + + _ "modernc.org/sqlite" +) + +// ModelType 模型类型结构 +type ModelType struct { + Key string `json:"key"` + Label string `json:"label"` +} + +func main() { + // 确保数据目录存在 + dbDir := "./data" + if err := os.MkdirAll(dbDir, 0755); err != nil { + log.Fatalf("创建数据目录失败: %v", err) + } + + // 连接数据库 + dbPath := filepath.Join(dbDir, "aimodels.db") + db, err := sql.Open("sqlite", dbPath) + if err != nil { + log.Fatalf("连接数据库失败: %v", err) + } + defer db.Close() + + // 创建model_type表 + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS model_type ( + key TEXT PRIMARY KEY, + label TEXT NOT NULL + ) + `) + if err != nil { + log.Fatalf("创建model_type表失败: %v", err) + } + + // 初始化默认的模型类型 + defaultTypes := []ModelType{ + {Key: "text2text", Label: "文生文"}, + {Key: "text2image", Label: "文生图"}, + {Key: "text2speech", Label: "文生音"}, + {Key: "speech2text", Label: "音生文"}, + {Key: "image2text", Label: "图生文"}, + {Key: "embedding", Label: "向量"}, + {Key: "other", Label: "其他"}, + } + + // 插入默认类型 + for _, t := range defaultTypes { + _, err = db.Exec(` + INSERT OR IGNORE INTO model_type (key, label) + VALUES (?, ?) + `, t.Key, t.Label) + if err != nil { + log.Printf("插入默认类型失败 %s: %v", t.Key, err) + } + } + + // 检查model_type列是否存在 + var hasModelType bool + err = db.QueryRow(` + SELECT COUNT(*) > 0 + FROM pragma_table_info('price') + WHERE name = 'model_type' + `).Scan(&hasModelType) + if err != nil { + log.Fatalf("检查model_type列失败: %v", err) + } + + // 如果model_type列不存在,则添加它 + if !hasModelType { + log.Println("开始添加model_type列...") + + // 开始事务 + tx, err := db.Begin() + if err != nil { + log.Fatalf("开始事务失败: %v", err) + } + + // 添加model_type列 + _, err = tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`) + if err != nil { + tx.Rollback() + log.Fatalf("添加model_type列失败: %v", err) + } + + // 添加temp_model_type列 + _, err = tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`) + if err != nil { + tx.Rollback() + log.Fatalf("添加temp_model_type列失败: %v", err) + } + + // 根据模型名称推断类型并更新 + rows, err := tx.Query(`SELECT id, model FROM price`) + if err != nil { + tx.Rollback() + log.Fatalf("查询价格数据失败: %v", err) + } + defer rows.Close() + + for rows.Next() { + var id int + var model string + if err := rows.Scan(&id, &model); err != nil { + tx.Rollback() + log.Fatalf("读取行数据失败: %v", err) + } + + // 根据模型名称推断类型 + modelType := inferModelType(model) + + // 更新model_type + _, err = tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, modelType, id) + if err != nil { + tx.Rollback() + log.Fatalf("更新model_type失败: %v", err) + } + } + + // 提交事务 + if err := tx.Commit(); err != nil { + log.Fatalf("提交事务失败: %v", err) + } + + log.Println("成功添加并更新model_type列") + } else { + log.Println("model_type列已存在,无需迁移") + } +} + +// inferModelType 根据模型名称推断模型类型 +func inferModelType(model string) string { + model = strings.ToLower(model) + + switch { + case strings.Contains(model, "gpt") || + strings.Contains(model, "llama") || + strings.Contains(model, "claude") || + strings.Contains(model, "palm") || + strings.Contains(model, "gemini") || + strings.Contains(model, "qwen") || + strings.Contains(model, "chatglm"): + return "text2text" + + case strings.Contains(model, "dall-e") || + strings.Contains(model, "stable") || + strings.Contains(model, "midjourney") || + strings.Contains(model, "sd") || + strings.Contains(model, "diffusion"): + return "text2image" + + case strings.Contains(model, "whisper") || + strings.Contains(model, "speech") || + strings.Contains(model, "tts"): + return "text2speech" + + case strings.Contains(model, "embedding") || + strings.Contains(model, "ada") || + strings.Contains(model, "text-embedding"): + return "embedding" + + default: + return "other" + } +} diff --git a/frontend/src/views/Home.vue b/frontend/src/views/Home.vue index 07b543c..a11958a 100644 --- a/frontend/src/views/Home.vue +++ b/frontend/src/views/Home.vue @@ -24,7 +24,7 @@

API文档

- +
GET diff --git a/frontend/src/views/Prices.vue b/frontend/src/views/Prices.vue index 26bbd45..5c0cb80 100644 --- a/frontend/src/views/Prices.vue +++ b/frontend/src/views/Prices.vue @@ -68,6 +68,16 @@
+ + + - - - - - - + + +