Remove SQLite migration and related code

This commit is contained in:
wood chen 2025-02-23 04:03:20 +08:00
parent c9a9e7b845
commit 330963418e
9 changed files with 32 additions and 463 deletions

View File

@ -24,16 +24,6 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \
rm main-* && \
chmod +x main
# 复制迁移工具
COPY backend/migrate-* ./
RUN if [ "$(uname -m)" = "aarch64" ]; then \
cp migrate-arm64 migrate; \
else \
cp migrate-amd64 migrate; \
fi && \
rm migrate-* && \
chmod +x migrate
COPY frontend/dist /app/frontend
COPY backend/config/nginx.conf /etc/nginx/nginx.conf
COPY scripts/start.sh ./

View File

@ -1,238 +0,0 @@
package main
import (
"database/sql"
"fmt"
"log"
"os"
"aimodels-prices/config"
"aimodels-prices/database"
)
func main() {
// 加载配置
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 检查SQLite数据库文件是否存在
if _, err := os.Stat(cfg.SQLitePath); os.IsNotExist(err) {
log.Printf("SQLite database file not found at %s, skipping migration", cfg.SQLitePath)
os.Exit(0)
}
// 连接SQLite数据库
sqliteDB, err := database.InitSQLiteDB(cfg.SQLitePath)
if err != nil {
log.Fatalf("Failed to connect to SQLite database: %v", err)
}
defer sqliteDB.Close()
// 初始化MySQL数据库
if err := database.InitDB(cfg); err != nil {
log.Fatalf("Failed to initialize MySQL database: %v", err)
}
defer database.DB.Close()
// 开始迁移数据
if err := migrateData(sqliteDB, database.DB); err != nil {
log.Fatalf("Failed to migrate data: %v", err)
}
log.Println("Data migration completed successfully!")
}
func migrateData(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
// 迁移用户数据
if err := migrateUsers(sqliteDB, mysqlDB); err != nil {
return fmt.Errorf("failed to migrate users: %v", err)
}
// 迁移会话数据
if err := migrateSessions(sqliteDB, mysqlDB); err != nil {
return fmt.Errorf("failed to migrate sessions: %v", err)
}
// 迁移提供商数据
if err := migrateProviders(sqliteDB, mysqlDB); err != nil {
return fmt.Errorf("failed to migrate providers: %v", err)
}
// 迁移价格数据
if err := migratePrices(sqliteDB, mysqlDB); err != nil {
return fmt.Errorf("failed to migrate prices: %v", err)
}
return nil
}
func migrateUsers(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
rows, err := sqliteDB.Query("SELECT id, username, email, role, created_at, updated_at, deleted_at FROM user")
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var (
id uint
username string
email string
role string
createdAt string
updatedAt string
deletedAt sql.NullString
)
if err := rows.Scan(&id, &username, &email, &role, &createdAt, &updatedAt, &deletedAt); err != nil {
return err
}
_, err = mysqlDB.Exec(
"INSERT INTO user (id, username, email, role, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, username, email, role, createdAt, updatedAt, deletedAt.String,
)
if err != nil {
return err
}
}
return rows.Err()
}
func migrateSessions(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
rows, err := sqliteDB.Query("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session")
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var (
id string
userID uint
expiresAt string
createdAt string
updatedAt string
deletedAt sql.NullString
)
if err := rows.Scan(&id, &userID, &expiresAt, &createdAt, &updatedAt, &deletedAt); err != nil {
return err
}
_, err = mysqlDB.Exec(
"INSERT INTO session (id, user_id, expires_at, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?)",
id, userID, expiresAt, createdAt, updatedAt, deletedAt.String,
)
if err != nil {
return err
}
}
return rows.Err()
}
func migrateProviders(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
rows, err := sqliteDB.Query("SELECT id, name, icon, created_at, updated_at, created_by FROM provider")
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var (
id uint
name string
icon sql.NullString
createdAt string
updatedAt string
createdBy string
)
if err := rows.Scan(&id, &name, &icon, &createdAt, &updatedAt, &createdBy); err != nil {
return err
}
_, err = mysqlDB.Exec(
"INSERT INTO provider (id, name, icon, created_at, updated_at, created_by) VALUES (?, ?, ?, ?, ?, ?)",
id, name, icon.String, createdAt, updatedAt, createdBy,
)
if err != nil {
return err
}
}
return rows.Err()
}
func migratePrices(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
rows, err := sqliteDB.Query(`
SELECT id, model, model_type, billing_type, channel_type, currency,
input_price, output_price, price_source, status, created_at, updated_at,
created_by, temp_model, temp_model_type, temp_billing_type, temp_channel_type,
temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by
FROM price
`)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var (
id uint
model string
modelType string
billingType string
channelType uint
currency string
inputPrice float64
outputPrice float64
priceSource string
status string
createdAt string
updatedAt string
createdBy string
tempModel sql.NullString
tempModelType sql.NullString
tempBillingType sql.NullString
tempChannelType sql.NullInt64
tempCurrency sql.NullString
tempInputPrice sql.NullFloat64
tempOutputPrice sql.NullFloat64
tempPriceSource sql.NullString
updatedBy sql.NullString
)
if err := rows.Scan(
&id, &model, &modelType, &billingType, &channelType, &currency,
&inputPrice, &outputPrice, &priceSource, &status, &createdAt, &updatedAt,
&createdBy, &tempModel, &tempModelType, &tempBillingType, &tempChannelType,
&tempCurrency, &tempInputPrice, &tempOutputPrice, &tempPriceSource, &updatedBy,
); err != nil {
return err
}
_, err = mysqlDB.Exec(`
INSERT INTO price (
id, model, model_type, billing_type, channel_type, currency,
input_price, output_price, price_source, status, created_at, updated_at,
created_by, temp_model, temp_model_type, temp_billing_type, temp_channel_type,
temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
id, model, modelType, billingType, channelType, currency,
inputPrice, outputPrice, priceSource, status, createdAt, updatedAt,
createdBy, tempModel.String, tempModelType.String, tempBillingType.String, tempChannelType.Int64,
tempCurrency.String, tempInputPrice.Float64, tempOutputPrice.Float64, tempPriceSource.String, updatedBy.String,
)
if err != nil {
return err
}
}
return rows.Err()
}

View File

@ -18,9 +18,6 @@ type Config struct {
// 其他配置
ServerPort string
// SQLite配置用于数据迁移
SQLitePath string
}
func LoadConfig() (*Config, error) {
@ -50,9 +47,6 @@ func LoadConfig() (*Config, error) {
// 其他配置
ServerPort: getEnv("PORT", "8080"),
// SQLite路径用于数据迁移
SQLitePath: filepath.Join(dbDir, "aimodels.db"),
}
return config, nil

View File

@ -6,7 +6,6 @@ import (
"log"
_ "github.com/go-sql-driver/mysql"
_ "modernc.org/sqlite"
"aimodels-prices/config"
"aimodels-prices/models"
@ -51,20 +50,6 @@ func InitDB(cfg *config.Config) error {
return nil
}
// InitSQLiteDB 初始化SQLite数据库连接用于数据迁移
func InitSQLiteDB(dbPath string) (*sql.DB, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to connect to SQLite: %v", err)
}
if err = db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping SQLite: %v", err)
}
return db, nil
}
// createTables 创建数据库表
func createTables() error {
// 创建用户表
@ -85,6 +70,12 @@ func createTables() error {
return err
}
// 创建模型类型表
if _, err := DB.Exec(models.CreateModelTypeTableSQL()); err != nil {
log.Printf("Failed to create model_type table: %v", err)
return err
}
// 创建价格表
if _, err := DB.Exec(models.CreatePriceTableSQL()); err != nil {
log.Printf("Failed to create price table: %v", err)

View File

@ -6,47 +6,31 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/go-sql-driver/mysql v1.7.1
github.com/joho/godotenv v1.5.1
modernc.org/sqlite v1.28.0
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.6.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/uint128 v1.2.0 // indirect
modernc.org/cc/v3 v3.40.0 // indirect
modernc.org/ccgo/v3 v3.16.13 // indirect
modernc.org/libc v1.29.0 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.7.2 // indirect
modernc.org/opt v0.1.3 // indirect
modernc.org/strutil v1.1.3 // indirect
modernc.org/token v1.0.1 // indirect
)

View File

@ -0,0 +1,16 @@
package models
// ModelType 模型类型结构
type ModelType struct {
TypeKey string `json:"key"`
TypeLabel string `json:"label"`
}
// CreateModelTypeTableSQL 返回创建模型类型表的 SQL
func CreateModelTypeTableSQL() string {
return `
CREATE TABLE IF NOT EXISTS model_type (
type_key VARCHAR(50) PRIMARY KEY,
type_label VARCHAR(255) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci`
}

10
backend/scripts/init.sql Normal file
View File

@ -0,0 +1,10 @@
-- 初始化模型类型数据
INSERT INTO model_type (type_key, type_label) VALUES
('text2text', '文生文'),
('text2image', '文生图'),
('text2speech', '文生音'),
('speech2text', '音生文'),
('image2text', '图生文'),
('embedding', '向量'),
('other', '其他')
ON DUPLICATE KEY UPDATE type_label = VALUES(type_label);

View File

@ -1,174 +0,0 @@
package main
import (
"database/sql"
"log"
"os"
"path/filepath"
"strings"
_ "modernc.org/sqlite"
)
// ModelType 模型类型结构
type ModelType struct {
Key string `json:"key"`
Label string `json:"label"`
}
func main() {
// 确保数据目录存在
dbDir := "./data"
if err := os.MkdirAll(dbDir, 0755); err != nil {
log.Fatalf("创建数据目录失败: %v", err)
}
// 连接数据库
dbPath := filepath.Join(dbDir, "aimodels.db")
db, err := sql.Open("sqlite", dbPath)
if err != nil {
log.Fatalf("连接数据库失败: %v", err)
}
defer db.Close()
// 创建model_type表
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS model_type (
key TEXT PRIMARY KEY,
label TEXT NOT NULL
)
`)
if err != nil {
log.Fatalf("创建model_type表失败: %v", err)
}
// 初始化默认的模型类型
defaultTypes := []ModelType{
{Key: "text2text", Label: "文生文"},
{Key: "text2image", Label: "文生图"},
{Key: "text2speech", Label: "文生音"},
{Key: "speech2text", Label: "音生文"},
{Key: "image2text", Label: "图生文"},
{Key: "embedding", Label: "向量"},
{Key: "other", Label: "其他"},
}
// 插入默认类型
for _, t := range defaultTypes {
_, err = db.Exec(`
INSERT OR IGNORE INTO model_type (key, label)
VALUES (?, ?)
`, t.Key, t.Label)
if err != nil {
log.Printf("插入默认类型失败 %s: %v", t.Key, err)
}
}
// 检查model_type列是否存在
var hasModelType bool
err = db.QueryRow(`
SELECT COUNT(*) > 0
FROM pragma_table_info('price')
WHERE name = 'model_type'
`).Scan(&hasModelType)
if err != nil {
log.Fatalf("检查model_type列失败: %v", err)
}
// 如果model_type列不存在,则添加它
if !hasModelType {
log.Println("开始添加model_type列...")
// 开始事务
tx, err := db.Begin()
if err != nil {
log.Fatalf("开始事务失败: %v", err)
}
// 添加model_type列
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`)
if err != nil {
tx.Rollback()
log.Fatalf("添加model_type列失败: %v", err)
}
// 添加temp_model_type列
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`)
if err != nil {
tx.Rollback()
log.Fatalf("添加temp_model_type列失败: %v", err)
}
// 根据模型名称推断类型并更新
rows, err := tx.Query(`SELECT id, model FROM price`)
if err != nil {
tx.Rollback()
log.Fatalf("查询价格数据失败: %v", err)
}
defer rows.Close()
for rows.Next() {
var id int
var model string
if err := rows.Scan(&id, &model); err != nil {
tx.Rollback()
log.Fatalf("读取行数据失败: %v", err)
}
// 根据模型名称推断类型
modelType := inferModelType(model)
// 更新model_type
_, err = tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, modelType, id)
if err != nil {
tx.Rollback()
log.Fatalf("更新model_type失败: %v", err)
}
}
// 提交事务
if err := tx.Commit(); err != nil {
log.Fatalf("提交事务失败: %v", err)
}
log.Println("成功添加并更新model_type列")
} else {
log.Println("model_type列已存在,无需迁移")
}
}
// inferModelType 根据模型名称推断模型类型
func inferModelType(model string) string {
model = strings.ToLower(model)
switch {
case strings.Contains(model, "gpt") ||
strings.Contains(model, "llama") ||
strings.Contains(model, "claude") ||
strings.Contains(model, "palm") ||
strings.Contains(model, "gemini") ||
strings.Contains(model, "qwen") ||
strings.Contains(model, "chatglm"):
return "text2text"
case strings.Contains(model, "dall-e") ||
strings.Contains(model, "stable") ||
strings.Contains(model, "midjourney") ||
strings.Contains(model, "sd") ||
strings.Contains(model, "diffusion"):
return "text2image"
case strings.Contains(model, "whisper") ||
strings.Contains(model, "speech") ||
strings.Contains(model, "tts"):
return "text2speech"
case strings.Contains(model, "embedding") ||
strings.Contains(model, "ada") ||
strings.Contains(model, "text-embedding"):
return "embedding"
default:
return "other"
}
}

View File

@ -1,9 +1,5 @@
#!/bin/bash
# 执行数据库迁移
echo "执行数据库迁移..."
./migrate
# 启动后端服务
echo "启动后端服务..."
./main &