239 lines
6.2 KiB
Go

package main
import (
"database/sql"
"fmt"
"log"
"os"
"aimodels-prices/config"
"aimodels-prices/database"
)
// main runs a one-shot SQLite -> MySQL data migration.
//
// It loads the configuration, returns early (successfully) when the SQLite
// file named by cfg.SQLitePath does not exist, and otherwise opens both
// databases and copies the data across with migrateData. Any failure is
// fatal. Note that log.Fatalf terminates the process without running the
// deferred Close calls.
func main() {
	// Load configuration.
	cfg, err := config.LoadConfig()
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}

	// Check whether the SQLite database file exists. Only a "not exist"
	// result means there is nothing to migrate; any other Stat failure
	// (permissions, I/O error, ...) is reported instead of being ignored.
	if _, err := os.Stat(cfg.SQLitePath); err != nil {
		if os.IsNotExist(err) {
			log.Printf("SQLite database file not found at %s, skipping migration", cfg.SQLitePath)
			return
		}
		log.Fatalf("Failed to check SQLite database file %s: %v", cfg.SQLitePath, err)
	}

	// Connect to the SQLite database.
	sqliteDB, err := database.InitSQLiteDB(cfg.SQLitePath)
	if err != nil {
		log.Fatalf("Failed to connect to SQLite database: %v", err)
	}
	defer sqliteDB.Close()

	// Initialize the MySQL database; the handle is exposed as database.DB.
	if err := database.InitDB(cfg); err != nil {
		log.Fatalf("Failed to initialize MySQL database: %v", err)
	}
	defer database.DB.Close()

	// Migrate the data.
	if err := migrateData(sqliteDB, database.DB); err != nil {
		log.Fatalf("Failed to migrate data: %v", err)
	}
	log.Println("Data migration completed successfully!")
}
// migrateData copies every table from sqliteDB into mysqlDB.
//
// The tables are processed in a fixed order: users, sessions, providers,
// prices. The first step that fails aborts the migration; its error is
// returned wrapped with %w so callers can still inspect the underlying cause
// with errors.Is / errors.As. Returns nil when every step succeeds.
func migrateData(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
	steps := []struct {
		name string
		run  func(src, dst *sql.DB) error
	}{
		{"users", migrateUsers},
		{"sessions", migrateSessions},
		{"providers", migrateProviders},
		{"prices", migratePrices},
	}

	for _, step := range steps {
		if err := step.run(sqliteDB, mysqlDB); err != nil {
			return fmt.Errorf("failed to migrate %s: %w", step.name, err)
		}
	}
	return nil
}
// migrateUsers copies every row of the SQLite `user` table into the MySQL
// `user` table, keeping ids and timestamps as read.
//
// deleted_at is nullable in the source. It is forwarded as a sql.NullString
// (which implements driver.Valuer), so a NULL in SQLite is written as NULL in
// MySQL instead of being rewritten to an empty string.
//
// Returns the first query, scan, insert or row-iteration error encountered.
func migrateUsers(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
	rows, err := sqliteDB.Query("SELECT id, username, email, role, created_at, updated_at, deleted_at FROM user")
	if err != nil {
		return fmt.Errorf("querying users: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id        uint
			username  string
			email     string
			role      string
			createdAt string
			updatedAt string
			deletedAt sql.NullString
		)
		if err := rows.Scan(&id, &username, &email, &role, &createdAt, &updatedAt, &deletedAt); err != nil {
			return fmt.Errorf("scanning user row: %w", err)
		}

		// Pass deletedAt itself, not deletedAt.String, to keep NULL as NULL.
		if _, err := mysqlDB.Exec(
			"INSERT INTO user (id, username, email, role, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
			id, username, email, role, createdAt, updatedAt, deletedAt,
		); err != nil {
			return fmt.Errorf("inserting user %d: %w", id, err)
		}
	}
	return rows.Err()
}
// migrateSessions copies every row of the SQLite `session` table into the
// MySQL `session` table, keeping ids and timestamps as read.
//
// deleted_at is nullable in the source. It is forwarded as a sql.NullString
// (which implements driver.Valuer), so a NULL in SQLite is written as NULL in
// MySQL instead of being rewritten to an empty string.
//
// Returns the first query, scan, insert or row-iteration error encountered.
func migrateSessions(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
	rows, err := sqliteDB.Query("SELECT id, user_id, expires_at, created_at, updated_at, deleted_at FROM session")
	if err != nil {
		return fmt.Errorf("querying sessions: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id        string
			userID    uint
			expiresAt string
			createdAt string
			updatedAt string
			deletedAt sql.NullString
		)
		if err := rows.Scan(&id, &userID, &expiresAt, &createdAt, &updatedAt, &deletedAt); err != nil {
			return fmt.Errorf("scanning session row: %w", err)
		}

		// Pass deletedAt itself, not deletedAt.String, to keep NULL as NULL.
		if _, err := mysqlDB.Exec(
			"INSERT INTO session (id, user_id, expires_at, created_at, updated_at, deleted_at) VALUES (?, ?, ?, ?, ?, ?)",
			id, userID, expiresAt, createdAt, updatedAt, deletedAt,
		); err != nil {
			return fmt.Errorf("inserting session %q: %w", id, err)
		}
	}
	return rows.Err()
}
// migrateProviders copies every row of the SQLite `provider` table into the
// MySQL `provider` table, keeping ids and timestamps as read.
//
// icon is read as nullable. It is forwarded as a sql.NullString (which
// implements driver.Valuer), so a NULL icon is written as NULL instead of
// being rewritten to an empty string.
// NOTE(review): this assumes the MySQL provider.icon column accepts NULL —
// confirm against the target schema.
//
// Returns the first query, scan, insert or row-iteration error encountered.
func migrateProviders(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
	rows, err := sqliteDB.Query("SELECT id, name, icon, created_at, updated_at, created_by FROM provider")
	if err != nil {
		return fmt.Errorf("querying providers: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id        uint
			name      string
			icon      sql.NullString
			createdAt string
			updatedAt string
			createdBy string
		)
		if err := rows.Scan(&id, &name, &icon, &createdAt, &updatedAt, &createdBy); err != nil {
			return fmt.Errorf("scanning provider row: %w", err)
		}

		// Pass icon itself, not icon.String, to keep NULL as NULL.
		if _, err := mysqlDB.Exec(
			"INSERT INTO provider (id, name, icon, created_at, updated_at, created_by) VALUES (?, ?, ?, ?, ?, ?)",
			id, name, icon, createdAt, updatedAt, createdBy,
		); err != nil {
			return fmt.Errorf("inserting provider %d: %w", id, err)
		}
	}
	return rows.Err()
}
// migratePrices copies every row of the SQLite `price` table into the MySQL
// `price` table, keeping ids and timestamps as read.
//
// The temp_* columns and updated_by are read as nullable. They are forwarded
// as sql.NullString / sql.NullFloat64 (both implement driver.Valuer), so a
// NULL in SQLite is written as NULL in MySQL instead of being flattened to an
// empty string or 0.
//
// Returns the first query, scan, insert or row-iteration error encountered.
func migratePrices(sqliteDB *sql.DB, mysqlDB *sql.DB) error {
	rows, err := sqliteDB.Query(`
SELECT id, model, model_type, billing_type, channel_type, currency,
input_price, output_price, price_source, status, created_at, updated_at,
created_by, temp_model, temp_model_type, temp_billing_type, temp_channel_type,
temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by
FROM price
`)
	if err != nil {
		return fmt.Errorf("querying prices: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id              uint
			model           string
			modelType       string
			billingType     string
			channelType     string
			currency        string
			inputPrice      float64
			outputPrice     float64
			priceSource     string
			status          string
			createdAt       string
			updatedAt       string
			createdBy       string
			tempModel       sql.NullString
			tempModelType   sql.NullString
			tempBillingType sql.NullString
			tempChannelType sql.NullString
			tempCurrency    sql.NullString
			tempInputPrice  sql.NullFloat64
			tempOutputPrice sql.NullFloat64
			tempPriceSource sql.NullString
			updatedBy       sql.NullString
		)
		if err := rows.Scan(
			&id, &model, &modelType, &billingType, &channelType, &currency,
			&inputPrice, &outputPrice, &priceSource, &status, &createdAt, &updatedAt,
			&createdBy, &tempModel, &tempModelType, &tempBillingType, &tempChannelType,
			&tempCurrency, &tempInputPrice, &tempOutputPrice, &tempPriceSource, &updatedBy,
		); err != nil {
			return fmt.Errorf("scanning price row: %w", err)
		}

		// Nullable values are passed as Null* wrappers (not .String /
		// .Float64) so that NULL is preserved in the destination.
		if _, err := mysqlDB.Exec(`
INSERT INTO price (
id, model, model_type, billing_type, channel_type, currency,
input_price, output_price, price_source, status, created_at, updated_at,
created_by, temp_model, temp_model_type, temp_billing_type, temp_channel_type,
temp_currency, temp_input_price, temp_output_price, temp_price_source, updated_by
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
			id, model, modelType, billingType, channelType, currency,
			inputPrice, outputPrice, priceSource, status, createdAt, updatedAt,
			createdBy, tempModel, tempModelType, tempBillingType, tempChannelType,
			tempCurrency, tempInputPrice, tempOutputPrice, tempPriceSource, updatedBy,
		); err != nil {
			return fmt.Errorf("inserting price %d: %w", id, err)
		}
	}
	return rows.Err()
}