// Mirror of https://github.com/woodchen-ink/aimodels-prices.git
// (synced 2025-07-18 13:41:59 +08:00; 175 lines, 4.0 KiB, Go).

package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// ModelType describes one model category as stored in the model_type table.
type ModelType struct {
	// Key is the category identifier; it is written to the table's primary-key
	// "key" column (for example "text2text").
	Key string `json:"key"`
	// Label is the display name written to the table's "label" column.
	Label string `json:"label"`
}
|
|
|
|
func main() {
|
|
// 确保数据目录存在
|
|
dbDir := "./data"
|
|
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
|
log.Fatalf("创建数据目录失败: %v", err)
|
|
}
|
|
|
|
// 连接数据库
|
|
dbPath := filepath.Join(dbDir, "aimodels.db")
|
|
db, err := sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
log.Fatalf("连接数据库失败: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// 创建model_type表
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS model_type (
|
|
key TEXT PRIMARY KEY,
|
|
label TEXT NOT NULL
|
|
)
|
|
`)
|
|
if err != nil {
|
|
log.Fatalf("创建model_type表失败: %v", err)
|
|
}
|
|
|
|
// 初始化默认的模型类型
|
|
defaultTypes := []ModelType{
|
|
{Key: "text2text", Label: "文生文"},
|
|
{Key: "text2image", Label: "文生图"},
|
|
{Key: "text2speech", Label: "文生音"},
|
|
{Key: "speech2text", Label: "音生文"},
|
|
{Key: "image2text", Label: "图生文"},
|
|
{Key: "embedding", Label: "向量"},
|
|
{Key: "other", Label: "其他"},
|
|
}
|
|
|
|
// 插入默认类型
|
|
for _, t := range defaultTypes {
|
|
_, err = db.Exec(`
|
|
INSERT OR IGNORE INTO model_type (key, label)
|
|
VALUES (?, ?)
|
|
`, t.Key, t.Label)
|
|
if err != nil {
|
|
log.Printf("插入默认类型失败 %s: %v", t.Key, err)
|
|
}
|
|
}
|
|
|
|
// 检查model_type列是否存在
|
|
var hasModelType bool
|
|
err = db.QueryRow(`
|
|
SELECT COUNT(*) > 0
|
|
FROM pragma_table_info('price')
|
|
WHERE name = 'model_type'
|
|
`).Scan(&hasModelType)
|
|
if err != nil {
|
|
log.Fatalf("检查model_type列失败: %v", err)
|
|
}
|
|
|
|
// 如果model_type列不存在,则添加它
|
|
if !hasModelType {
|
|
log.Println("开始添加model_type列...")
|
|
|
|
// 开始事务
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
log.Fatalf("开始事务失败: %v", err)
|
|
}
|
|
|
|
// 添加model_type列
|
|
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
log.Fatalf("添加model_type列失败: %v", err)
|
|
}
|
|
|
|
// 添加temp_model_type列
|
|
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
log.Fatalf("添加temp_model_type列失败: %v", err)
|
|
}
|
|
|
|
// 根据模型名称推断类型并更新
|
|
rows, err := tx.Query(`SELECT id, model FROM price`)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
log.Fatalf("查询价格数据失败: %v", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var id int
|
|
var model string
|
|
if err := rows.Scan(&id, &model); err != nil {
|
|
tx.Rollback()
|
|
log.Fatalf("读取行数据失败: %v", err)
|
|
}
|
|
|
|
// 根据模型名称推断类型
|
|
modelType := inferModelType(model)
|
|
|
|
// 更新model_type
|
|
_, err = tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, modelType, id)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
log.Fatalf("更新model_type失败: %v", err)
|
|
}
|
|
}
|
|
|
|
// 提交事务
|
|
if err := tx.Commit(); err != nil {
|
|
log.Fatalf("提交事务失败: %v", err)
|
|
}
|
|
|
|
log.Println("成功添加并更新model_type列")
|
|
} else {
|
|
log.Println("model_type列已存在,无需迁移")
|
|
}
|
|
}
|
|
|
|
// inferModelType guesses a model_type key from a model name.
//
// Matching is case-insensitive and substring-based. Rules are tried in order
// and the first match wins:
//
//  1. "embedding"                                        -> "embedding"
//  2. gpt, llama, claude, palm, gemini, qwen, chatglm    -> "text2text"
//  3. dall-e, stable, midjourney, sd, diffusion          -> "text2image"
//  4. whisper                                            -> "speech2text"
//  5. speech, tts                                        -> "text2speech"
//  6. ada                                                -> "embedding"
//  7. anything else (including the empty string)         -> "other"
//
// The explicit "embedding" rule runs first so that embedding models published
// under a chat-model family name (e.g. "gemini-embedding-001") are not
// swallowed by rule 2. Whisper is a speech-recognition model, so it maps to
// "speech2text" rather than "text2speech".
func inferModelType(model string) string {
	model = strings.ToLower(model)

	// containsAny reports whether the lowered name contains any of subs.
	containsAny := func(subs ...string) bool {
		for _, s := range subs {
			if strings.Contains(model, s) {
				return true
			}
		}
		return false
	}

	switch {
	case containsAny("embedding"):
		return "embedding"

	case containsAny("gpt", "llama", "claude", "palm", "gemini", "qwen", "chatglm"):
		return "text2text"

	case containsAny("dall-e", "stable", "midjourney", "sd", "diffusion"):
		return "text2image"

	case containsAny("whisper"):
		return "speech2text"

	case containsAny("speech", "tts"):
		return "text2speech"

	// Kept after the family checks, as in the original ordering, because "ada"
	// is a short substring that would otherwise match unrelated names.
	case containsAny("ada"):
		return "embedding"

	default:
		return "other"
	}
}
|