优化价格更新逻辑,增加模型处理的唯一性检查

- 在 UpdateOtherPrices 函数中新增已处理模型的跟踪机制,避免重复处理同一模型
- 增强模型数据处理逻辑,确保在创建新价格记录前检查待审核记录
- 更新 main.go,添加初始化任务的调用,提升系统启动时的功能完整性
This commit is contained in:
wood chen 2025-03-18 03:01:10 +08:00
parent 35a4936ff8
commit eb9e069f76
5 changed files with 143 additions and 58 deletions

View File

@ -98,6 +98,9 @@ func UpdateOtherPrices() error {
modelDataMap[modelKey][isFree] = &modelData
}
// 创建一个集合用于跟踪已处理的模型,避免重复
processedModels := make(map[string]bool)
// 第二遍遍历,根据处理规则选择合适的模型数据
for modelKey, variants := range modelDataMap {
var modelData *ModelData
@ -159,6 +162,19 @@ func UpdateOtherPrices() error {
}
}
// 创建唯一标识符,用于避免同厂商同模型重复处理
uniqueModelKey := fmt.Sprintf("%d:%s", channelType, modelName)
// 检查是否已处理过这个模型
if processedModels[uniqueModelKey] {
log.Printf("跳过已处理的模型: %s (厂商: %s)", modelName, author)
skippedCount++
continue
}
// 标记此模型为已处理
processedModels[uniqueModelKey] = true
// 确定模型类型
modelType := determineModelType(modelData.Modality)
@ -227,6 +243,19 @@ func UpdateOtherPrices() error {
skippedCount++
}
} else {
// 检查是否存在相同模型名称的待审核记录
var pendingCount int64
if err := db.Model(&models.Price{}).Where("model = ? AND channel_type = ? AND status = 'pending'",
modelName, channelType).Count(&pendingCount).Error; err != nil {
log.Printf("检查待审核记录失败 %s: %v", modelName, err)
}
if pendingCount > 0 {
log.Printf("已存在待审核的相同模型记录,跳过创建: %s (厂商: %s)", modelName, author)
skippedCount++
continue
}
// 使用processPrice函数处理创建
_, changed, err := handlers.ProcessPrice(price, nil, false, CreatedBy)
if err != nil {

View File

@ -246,64 +246,6 @@ func cachePriceRates() {
modelMap[price.Model][price.ChannelType] = price
}
// 计算倍率
type PriceRate struct {
Model string `json:"model"`
ModelType string `json:"model_type"`
Type string `json:"type"`
ChannelType uint `json:"channel_type"`
Input float64 `json:"input"`
Output float64 `json:"output"`
}
var rates []PriceRate
for model, providers := range modelMap {
// 找出基准价格通常是OpenAI的价格
var basePrice models.Price
var found bool
for _, price := range providers {
if price.ChannelType == 1 { // 假设OpenAI的ID是1
basePrice = price
found = true
break
}
}
if !found {
continue
}
// 计算其他厂商相对于基准价格的倍率
for channelType, price := range providers {
if channelType == 1 {
continue // 跳过基准价格
}
// 计算输入和输出的倍率
inputRate := 0.0
if basePrice.InputPrice > 0 {
inputRate = price.InputPrice / basePrice.InputPrice
}
outputRate := 0.0
if basePrice.OutputPrice > 0 {
outputRate = price.OutputPrice / basePrice.OutputPrice
}
rates = append(rates, PriceRate{
Model: model,
ModelType: price.ModelType,
Type: price.BillingType,
ChannelType: channelType,
Input: inputRate,
Output: outputRate,
})
}
}
GlobalCache.Set("price_rates", rates, 10*time.Minute)
log.Printf("已缓存 %d 个价格倍率", len(rates))
// 缓存常用的价格查询
cachePriceQueries()
}

16
backend/init/main.go Normal file
View File

@ -0,0 +1,16 @@
package init
import (
"log"
)
// RunInitTasks 运行所有初始化任务
func RunInitTasks() {
// 检查并处理重复的模型名称
if err := CheckDuplicateModelNames(); err != nil {
log.Printf("检查重复模型名称时发生错误: %v", err)
}
// 在此处添加其他初始化任务
// ...
}

View File

@ -0,0 +1,94 @@
package init
import (
"aimodels-prices/database"
"aimodels-prices/models"
"log"
)
// CheckDuplicateModelNames 检查数据库中是否存在重复的模型名称,如果有则保留最新的
func CheckDuplicateModelNames() error {
log.Println("开始检查重复的模型名称...")
db := database.DB
if db == nil {
return nil
}
// 查找所有具有重复模型名称的厂商ID和模型名称组合
var duplicates []struct {
ChannelType uint `json:"channel_type"`
Model string `json:"model"`
Count int `json:"count"`
}
if err := db.Raw(`
SELECT channel_type, model, COUNT(*) as count
FROM price
GROUP BY channel_type, model
HAVING COUNT(*) > 1
`).Scan(&duplicates).Error; err != nil {
return err
}
if len(duplicates) == 0 {
log.Println("没有发现重复的模型名称")
return nil
}
log.Printf("发现 %d 组重复的模型名称,正在处理...", len(duplicates))
processedCount := 0
// 开始事务
tx := db.Begin()
if tx.Error != nil {
return tx.Error
}
// 处理每一组重复
for _, dup := range duplicates {
// 查找具有相同厂商ID和模型名称的所有记录
var prices []models.Price
if err := tx.Where("channel_type = ? AND model = ?", dup.ChannelType, dup.Model).Order("updated_at DESC").Find(&prices).Error; err != nil {
tx.Rollback()
return err
}
if len(prices) <= 1 {
continue // 安全检查,实际上这不应该发生
}
// 保留最新的记录(按更新时间排序后的第一个),删除其他记录
latestID := prices[0].ID
log.Printf("保留最新的记录: ID=%v, 模型=%s, 厂商ID=%d, 更新时间=%v",
latestID, dup.Model, dup.ChannelType, prices[0].UpdatedAt)
// 收集要删除的ID
var idsToDelete []uint
for i := 1; i < len(prices); i++ {
idsToDelete = append(idsToDelete, prices[i].ID)
log.Printf("删除重复记录: ID=%v, 模型=%s, 厂商ID=%d, 更新时间=%v",
prices[i].ID, dup.Model, dup.ChannelType, prices[i].UpdatedAt)
}
// 删除重复记录
if len(idsToDelete) > 0 {
if err := tx.Delete(&models.Price{}, idsToDelete).Error; err != nil {
tx.Rollback()
return err
}
processedCount += len(idsToDelete)
}
}
// 提交事务
if err := tx.Commit().Error; err != nil {
tx.Rollback()
return err
}
// 清除缓存
database.GlobalCache.Clear()
log.Printf("重复模型名称处理完成,共删除 %d 条重复记录", processedCount)
return nil
}

View File

@ -11,6 +11,7 @@ import (
"aimodels-prices/database"
"aimodels-prices/handlers"
"aimodels-prices/handlers/rates"
initTasks "aimodels-prices/init"
"aimodels-prices/middleware"
)
@ -26,6 +27,9 @@ func main() {
log.Fatalf("Failed to initialize database: %v", err)
}
// 运行初始化任务
initTasks.RunInitTasks()
// 设置gin模式
if gin.Mode() == gin.ReleaseMode {
gin.SetMode(gin.ReleaseMode)