diff --git a/backend/cron/openrouter-api/update-other-price.go b/backend/cron/openrouter-api/update-other-price.go index 349428d..a036eff 100644 --- a/backend/cron/openrouter-api/update-other-price.go +++ b/backend/cron/openrouter-api/update-other-price.go @@ -98,6 +98,9 @@ func UpdateOtherPrices() error { modelDataMap[modelKey][isFree] = &modelData } + // 创建一个集合用于跟踪已处理的模型,避免重复 + processedModels := make(map[string]bool) + // 第二遍遍历,根据处理规则选择合适的模型数据 for modelKey, variants := range modelDataMap { var modelData *ModelData @@ -159,6 +162,19 @@ func UpdateOtherPrices() error { } } + // 创建唯一标识符,用于避免同厂商同模型重复处理 + uniqueModelKey := fmt.Sprintf("%d:%s", channelType, modelName) + + // 检查是否已处理过这个模型 + if processedModels[uniqueModelKey] { + log.Printf("跳过已处理的模型: %s (厂商: %s)", modelName, author) + skippedCount++ + continue + } + + // 标记此模型为已处理 + processedModels[uniqueModelKey] = true + // 确定模型类型 modelType := determineModelType(modelData.Modality) @@ -227,6 +243,19 @@ func UpdateOtherPrices() error { skippedCount++ } } else { + // 检查是否存在相同模型名称的待审核记录 + var pendingCount int64 + if err := db.Model(&models.Price{}).Where("model = ? AND channel_type = ? AND status = 'pending'", + modelName, channelType).Count(&pendingCount).Error; err != nil { + log.Printf("检查待审核记录失败 %s: %v", modelName, err) + } + + if pendingCount > 0 { + log.Printf("已存在待审核的相同模型记录,跳过创建: %s (厂商: %s)", modelName, author) + skippedCount++ + continue + } + // 使用processPrice函数处理创建 _, changed, err := handlers.ProcessPrice(price, nil, false, CreatedBy) if err != nil { diff --git a/backend/database/db.go b/backend/database/db.go index 8c497ac..c6f3898 100644 --- a/backend/database/db.go +++ b/backend/database/db.go @@ -246,64 +246,6 @@ func cachePriceRates() { modelMap[price.Model][price.ChannelType] = price } - // 计算倍率 - type PriceRate struct { - Model string `json:"model"` - ModelType string `json:"model_type"` - Type string `json:"type"` - ChannelType uint `json:"channel_type"` - Input float64 `json:"input"` - Output float64 `json:"output"` - } - - var rates []PriceRate - for model, providers := range modelMap { - // 找出基准价格(通常是OpenAI的价格) - var basePrice models.Price - var found bool - for _, price := range providers { - if price.ChannelType == 1 { // 假设OpenAI的ID是1 - basePrice = price - found = true - break - } - } - - if !found { - continue - } - - // 计算其他厂商相对于基准价格的倍率 - for channelType, price := range providers { - if channelType == 1 { - continue // 跳过基准价格 - } - - // 计算输入和输出的倍率 - inputRate := 0.0 - if basePrice.InputPrice > 0 { - inputRate = price.InputPrice / basePrice.InputPrice - } - - outputRate := 0.0 - if basePrice.OutputPrice > 0 { - outputRate = price.OutputPrice / basePrice.OutputPrice - } - - rates = append(rates, PriceRate{ - Model: model, - ModelType: price.ModelType, - Type: price.BillingType, - ChannelType: channelType, - Input: inputRate, - Output: outputRate, - }) - } - } - - GlobalCache.Set("price_rates", rates, 10*time.Minute) - log.Printf("已缓存 %d 个价格倍率", len(rates)) - // 缓存常用的价格查询 cachePriceQueries() } diff --git a/backend/init/main.go b/backend/init/main.go new file mode 100644 index 0000000..afaa7b6 --- /dev/null +++ b/backend/init/main.go @@ -0,0 +1,16 @@ +package init + +import ( + "log" +) + +// RunInitTasks 运行所有初始化任务 +func RunInitTasks() { + // 检查并处理重复的模型名称 + if err := CheckDuplicateModelNames(); err != nil { + log.Printf("检查重复模型名称时发生错误: %v", err) + } + + // 在此处添加其他初始化任务 + // ... +} diff --git a/backend/init/modelname-check.go b/backend/init/modelname-check.go new file mode 100644 index 0000000..a162cf5 --- /dev/null +++ b/backend/init/modelname-check.go @@ -0,0 +1,94 @@ +package init + +import ( + "aimodels-prices/database" + "aimodels-prices/models" + "log" +) + +// CheckDuplicateModelNames 检查数据库中是否存在重复的模型名称,如果有则保留最新的 +func CheckDuplicateModelNames() error { + log.Println("开始检查重复的模型名称...") + db := database.DB + if db == nil { + return nil + } + + // 查找所有具有重复模型名称的厂商ID和模型名称组合 + var duplicates []struct { + ChannelType uint `json:"channel_type"` + Model string `json:"model"` + Count int `json:"count"` + } + + if err := db.Raw(` + SELECT channel_type, model, COUNT(*) as count + FROM price + GROUP BY channel_type, model + HAVING COUNT(*) > 1 + `).Scan(&duplicates).Error; err != nil { + return err + } + + if len(duplicates) == 0 { + log.Println("没有发现重复的模型名称") + return nil + } + + log.Printf("发现 %d 组重复的模型名称,正在处理...", len(duplicates)) + processedCount := 0 + + // 开始事务 + tx := db.Begin() + if tx.Error != nil { + return tx.Error + } + + // 处理每一组重复 + for _, dup := range duplicates { + // 查找具有相同厂商ID和模型名称的所有记录 + var prices []models.Price + if err := tx.Where("channel_type = ? AND model = ?", dup.ChannelType, dup.Model).Order("updated_at DESC").Find(&prices).Error; err != nil { + tx.Rollback() + return err + } + + if len(prices) <= 1 { + continue // 安全检查,实际上这不应该发生 + } + + // 保留最新的记录(按更新时间排序后的第一个),删除其他记录 + latestID := prices[0].ID + log.Printf("保留最新的记录: ID=%v, 模型=%s, 厂商ID=%d, 更新时间=%v", + latestID, dup.Model, dup.ChannelType, prices[0].UpdatedAt) + + // 收集要删除的ID + var idsToDelete []uint + for i := 1; i < len(prices); i++ { + idsToDelete = append(idsToDelete, prices[i].ID) + log.Printf("删除重复记录: ID=%v, 模型=%s, 厂商ID=%d, 更新时间=%v", + prices[i].ID, dup.Model, dup.ChannelType, prices[i].UpdatedAt) + } + + // 删除重复记录 + if len(idsToDelete) > 0 { + if err := tx.Delete(&models.Price{}, idsToDelete).Error; err != nil { + tx.Rollback() + return err + } + processedCount += len(idsToDelete) + } + } + + // 提交事务 + if err := tx.Commit().Error; err != nil { + tx.Rollback() + return err + } + + // 清除缓存 + database.GlobalCache.Clear() + + log.Printf("重复模型名称处理完成,共删除 %d 条重复记录", processedCount) + return nil +} diff --git a/backend/main.go b/backend/main.go index 3763a2b..2af25cb 100644 --- a/backend/main.go +++ b/backend/main.go @@ -11,6 +11,7 @@ import ( "aimodels-prices/database" "aimodels-prices/handlers" "aimodels-prices/handlers/rates" + initTasks "aimodels-prices/init" "aimodels-prices/middleware" ) @@ -26,6 +27,9 @@ func main() { log.Fatalf("Failed to initialize database: %v", err) } + // 运行初始化任务 + initTasks.RunInitTasks() + // 设置gin模式 if gin.Mode() == gin.ReleaseMode { gin.SetMode(gin.ReleaseMode)