175 lines
4.0 KiB
Go

package main
import (
	"database/sql"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"

	_ "modernc.org/sqlite"
)
// ModelType describes one entry of the model_type lookup table: a machine
// key (for example "text2image") paired with a human-readable label.
type ModelType struct {
	Key   string `json:"key"`   // machine-readable identifier (model_type.key)
	Label string `json:"label"` // display label (model_type.label)
}
// main runs a one-off schema migration against ./data/aimodels.db:
//  1. creates the data directory and the model_type lookup table if missing,
//  2. seeds model_type with the default types (existing keys are kept),
//  3. if the price table has no model_type column yet, adds it (plus
//     temp_model_type) and back-fills model_type from each row's model name.
//
// Fatal errors exit via log.Fatal/log.Fatalf; os.Exit skips deferred calls,
// so db.Close does not run on those paths.
func main() {
	// Make sure the data directory exists.
	dbDir := "./data"
	if err := os.MkdirAll(dbDir, 0755); err != nil {
		log.Fatalf("创建数据目录失败: %v", err)
	}

	// Open the database; the "sqlite" driver comes from the blank
	// modernc.org/sqlite import.
	dbPath := filepath.Join(dbDir, "aimodels.db")
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		log.Fatalf("连接数据库失败: %v", err)
	}
	defer db.Close()

	// Create the model_type lookup table.
	if _, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS model_type (
			key TEXT PRIMARY KEY,
			label TEXT NOT NULL
		)
	`); err != nil {
		log.Fatalf("创建model_type表失败: %v", err)
	}

	// Seed the default model types. INSERT OR IGNORE leaves rows whose key
	// already exists untouched; a failed insert is logged and skipped.
	defaultTypes := []ModelType{
		{Key: "text2text", Label: "文生文"},
		{Key: "text2image", Label: "文生图"},
		{Key: "text2speech", Label: "文生音"},
		{Key: "speech2text", Label: "音生文"},
		{Key: "image2text", Label: "图生文"},
		{Key: "embedding", Label: "向量"},
		{Key: "other", Label: "其他"},
	}
	for _, t := range defaultTypes {
		if _, err := db.Exec(`
			INSERT OR IGNORE INTO model_type (key, label)
			VALUES (?, ?)
		`, t.Key, t.Label); err != nil {
			log.Printf("插入默认类型失败 %s: %v", t.Key, err)
		}
	}

	// Check whether price already has a model_type column.
	var hasModelType bool
	if err := db.QueryRow(`
		SELECT COUNT(*) > 0
		FROM pragma_table_info('price')
		WHERE name = 'model_type'
	`).Scan(&hasModelType); err != nil {
		log.Fatalf("检查model_type列失败: %v", err)
	}
	if hasModelType {
		log.Println("model_type列已存在,无需迁移")
		return
	}

	log.Println("开始添加model_type列...")
	if err := migratePriceModelType(db); err != nil {
		// err already carries the step-specific message.
		log.Fatal(err)
	}
	log.Println("成功添加并更新model_type列")
}

// priceModel is one (id, model) pair read from the price table.
type priceModel struct {
	id    int
	model string
}

// migratePriceModelType adds the model_type and temp_model_type columns to
// price and fills model_type for every existing row using inferModelType.
// Everything runs in a single transaction that is rolled back on any error.
// Returned errors are wrapped with the message of the failing step.
func migratePriceModelType(db *sql.DB) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("开始事务失败: %w", err)
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone,
	// so one deferred call covers every early-return path.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`); err != nil {
		return fmt.Errorf("添加model_type列失败: %w", err)
	}
	// NOTE(review): temp_model_type is created but never populated here —
	// confirm whether anything still depends on it.
	if _, err := tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`); err != nil {
		return fmt.Errorf("添加temp_model_type列失败: %w", err)
	}

	// Read all rows first so no UPDATE runs while the SELECT cursor is open.
	pending, err := loadPriceModels(tx)
	if err != nil {
		return err
	}
	for _, p := range pending {
		if _, err := tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, inferModelType(p.model), p.id); err != nil {
			return fmt.Errorf("更新model_type失败: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}

// loadPriceModels reads every (id, model) pair from price inside tx. It closes
// the result set before returning and reports iteration errors from rows.Err.
// NOTE(review): scanning into a plain string fails if model is NULL — assumes
// the column is NOT NULL; confirm against the price schema.
func loadPriceModels(tx *sql.Tx) ([]priceModel, error) {
	rows, err := tx.Query(`SELECT id, model FROM price`)
	if err != nil {
		return nil, fmt.Errorf("查询价格数据失败: %w", err)
	}
	defer rows.Close()

	var out []priceModel
	for rows.Next() {
		var p priceModel
		if err := rows.Scan(&p.id, &p.model); err != nil {
			return nil, fmt.Errorf("读取行数据失败: %w", err)
		}
		out = append(out, p)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("读取行数据失败: %w", err)
	}
	return out, nil
}
// modelTypeRules lists, in priority order, the lower-case name fragments that
// identify each model type; the first rule with any matching fragment wins.
// Modality-specific markers are checked before the generic LLM family names so
// that e.g. "gpt-4o-mini-tts" maps to text2speech and "gemini-embedding-001"
// to embedding instead of text2text.
var modelTypeRules = []struct {
	modelType string
	fragments []string
}{
	{"text2image", []string{"dall-e", "stable", "midjourney", "sd", "diffusion"}},
	// Whisper is a speech-recognition model: audio in, text out.
	{"speech2text", []string{"whisper"}},
	{"text2speech", []string{"speech", "tts"}},
	// "embedding" also covers the "text-embedding-*" family.
	{"embedding", []string{"embedding", "ada"}},
	{"text2text", []string{"gpt", "llama", "claude", "palm", "gemini", "qwen", "chatglm"}},
}

// inferModelType guesses a model_type key from a model name by
// case-insensitive substring matching against modelTypeRules, returning
// "other" when no fragment matches. This is a heuristic: short fragments such
// as "sd" and "ada" can also match unrelated names.
func inferModelType(model string) string {
	model = strings.ToLower(model)
	for _, rule := range modelTypeRules {
		for _, frag := range rule.fragments {
			if strings.Contains(model, frag) {
				return rule.modelType
			}
		}
	}
	return "other"
}