mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 13:41:59 +08:00
优化价格更新逻辑,增加模型处理的唯一性检查
- 在 UpdateOtherPrices 函数中新增已处理模型的跟踪机制,避免重复处理同一模型 - 增强模型数据处理逻辑,确保在创建新价格记录前检查待审核记录 - 更新 main.go,添加初始化任务的调用,提升系统启动时的功能完整性
This commit is contained in:
parent
35a4936ff8
commit
eb9e069f76
@ -98,6 +98,9 @@ func UpdateOtherPrices() error {
|
|||||||
modelDataMap[modelKey][isFree] = &modelData
|
modelDataMap[modelKey][isFree] = &modelData
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建一个集合用于跟踪已处理的模型,避免重复
|
||||||
|
processedModels := make(map[string]bool)
|
||||||
|
|
||||||
// 第二遍遍历,根据处理规则选择合适的模型数据
|
// 第二遍遍历,根据处理规则选择合适的模型数据
|
||||||
for modelKey, variants := range modelDataMap {
|
for modelKey, variants := range modelDataMap {
|
||||||
var modelData *ModelData
|
var modelData *ModelData
|
||||||
@ -159,6 +162,19 @@ func UpdateOtherPrices() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建唯一标识符,用于避免同厂商同模型重复处理
|
||||||
|
uniqueModelKey := fmt.Sprintf("%d:%s", channelType, modelName)
|
||||||
|
|
||||||
|
// 检查是否已处理过这个模型
|
||||||
|
if processedModels[uniqueModelKey] {
|
||||||
|
log.Printf("跳过已处理的模型: %s (厂商: %s)", modelName, author)
|
||||||
|
skippedCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记此模型为已处理
|
||||||
|
processedModels[uniqueModelKey] = true
|
||||||
|
|
||||||
// 确定模型类型
|
// 确定模型类型
|
||||||
modelType := determineModelType(modelData.Modality)
|
modelType := determineModelType(modelData.Modality)
|
||||||
|
|
||||||
@ -227,6 +243,19 @@ func UpdateOtherPrices() error {
|
|||||||
skippedCount++
|
skippedCount++
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// 检查是否存在相同模型名称的待审核记录
|
||||||
|
var pendingCount int64
|
||||||
|
if err := db.Model(&models.Price{}).Where("model = ? AND channel_type = ? AND status = 'pending'",
|
||||||
|
modelName, channelType).Count(&pendingCount).Error; err != nil {
|
||||||
|
log.Printf("检查待审核记录失败 %s: %v", modelName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pendingCount > 0 {
|
||||||
|
log.Printf("已存在待审核的相同模型记录,跳过创建: %s (厂商: %s)", modelName, author)
|
||||||
|
skippedCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// 使用processPrice函数处理创建
|
// 使用processPrice函数处理创建
|
||||||
_, changed, err := handlers.ProcessPrice(price, nil, false, CreatedBy)
|
_, changed, err := handlers.ProcessPrice(price, nil, false, CreatedBy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -246,64 +246,6 @@ func cachePriceRates() {
|
|||||||
modelMap[price.Model][price.ChannelType] = price
|
modelMap[price.Model][price.ChannelType] = price
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算倍率
|
|
||||||
type PriceRate struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
ChannelType uint `json:"channel_type"`
|
|
||||||
Input float64 `json:"input"`
|
|
||||||
Output float64 `json:"output"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var rates []PriceRate
|
|
||||||
for model, providers := range modelMap {
|
|
||||||
// 找出基准价格(通常是OpenAI的价格)
|
|
||||||
var basePrice models.Price
|
|
||||||
var found bool
|
|
||||||
for _, price := range providers {
|
|
||||||
if price.ChannelType == 1 { // 假设OpenAI的ID是1
|
|
||||||
basePrice = price
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算其他厂商相对于基准价格的倍率
|
|
||||||
for channelType, price := range providers {
|
|
||||||
if channelType == 1 {
|
|
||||||
continue // 跳过基准价格
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算输入和输出的倍率
|
|
||||||
inputRate := 0.0
|
|
||||||
if basePrice.InputPrice > 0 {
|
|
||||||
inputRate = price.InputPrice / basePrice.InputPrice
|
|
||||||
}
|
|
||||||
|
|
||||||
outputRate := 0.0
|
|
||||||
if basePrice.OutputPrice > 0 {
|
|
||||||
outputRate = price.OutputPrice / basePrice.OutputPrice
|
|
||||||
}
|
|
||||||
|
|
||||||
rates = append(rates, PriceRate{
|
|
||||||
Model: model,
|
|
||||||
ModelType: price.ModelType,
|
|
||||||
Type: price.BillingType,
|
|
||||||
ChannelType: channelType,
|
|
||||||
Input: inputRate,
|
|
||||||
Output: outputRate,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
GlobalCache.Set("price_rates", rates, 10*time.Minute)
|
|
||||||
log.Printf("已缓存 %d 个价格倍率", len(rates))
|
|
||||||
|
|
||||||
// 缓存常用的价格查询
|
// 缓存常用的价格查询
|
||||||
cachePriceQueries()
|
cachePriceQueries()
|
||||||
}
|
}
|
||||||
|
16
backend/init/main.go
Normal file
16
backend/init/main.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package init
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RunInitTasks 运行所有初始化任务
|
||||||
|
func RunInitTasks() {
|
||||||
|
// 检查并处理重复的模型名称
|
||||||
|
if err := CheckDuplicateModelNames(); err != nil {
|
||||||
|
log.Printf("检查重复模型名称时发生错误: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在此处添加其他初始化任务
|
||||||
|
// ...
|
||||||
|
}
|
94
backend/init/modelname-check.go
Normal file
94
backend/init/modelname-check.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package init
|
||||||
|
|
||||||
|
import (
|
||||||
|
"aimodels-prices/database"
|
||||||
|
"aimodels-prices/models"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CheckDuplicateModelNames 检查数据库中是否存在重复的模型名称,如果有则保留最新的
|
||||||
|
func CheckDuplicateModelNames() error {
|
||||||
|
log.Println("开始检查重复的模型名称...")
|
||||||
|
db := database.DB
|
||||||
|
if db == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找所有具有重复模型名称的厂商ID和模型名称组合
|
||||||
|
var duplicates []struct {
|
||||||
|
ChannelType uint `json:"channel_type"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Raw(`
|
||||||
|
SELECT channel_type, model, COUNT(*) as count
|
||||||
|
FROM price
|
||||||
|
GROUP BY channel_type, model
|
||||||
|
HAVING COUNT(*) > 1
|
||||||
|
`).Scan(&duplicates).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(duplicates) == 0 {
|
||||||
|
log.Println("没有发现重复的模型名称")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("发现 %d 组重复的模型名称,正在处理...", len(duplicates))
|
||||||
|
processedCount := 0
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
tx := db.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
return tx.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理每一组重复
|
||||||
|
for _, dup := range duplicates {
|
||||||
|
// 查找具有相同厂商ID和模型名称的所有记录
|
||||||
|
var prices []models.Price
|
||||||
|
if err := tx.Where("channel_type = ? AND model = ?", dup.ChannelType, dup.Model).Order("updated_at DESC").Find(&prices).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(prices) <= 1 {
|
||||||
|
continue // 安全检查,实际上这不应该发生
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保留最新的记录(按更新时间排序后的第一个),删除其他记录
|
||||||
|
latestID := prices[0].ID
|
||||||
|
log.Printf("保留最新的记录: ID=%v, 模型=%s, 厂商ID=%d, 更新时间=%v",
|
||||||
|
latestID, dup.Model, dup.ChannelType, prices[0].UpdatedAt)
|
||||||
|
|
||||||
|
// 收集要删除的ID
|
||||||
|
var idsToDelete []uint
|
||||||
|
for i := 1; i < len(prices); i++ {
|
||||||
|
idsToDelete = append(idsToDelete, prices[i].ID)
|
||||||
|
log.Printf("删除重复记录: ID=%v, 模型=%s, 厂商ID=%d, 更新时间=%v",
|
||||||
|
prices[i].ID, dup.Model, dup.ChannelType, prices[i].UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除重复记录
|
||||||
|
if len(idsToDelete) > 0 {
|
||||||
|
if err := tx.Delete(&models.Price{}, idsToDelete).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
processedCount += len(idsToDelete)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除缓存
|
||||||
|
database.GlobalCache.Clear()
|
||||||
|
|
||||||
|
log.Printf("重复模型名称处理完成,共删除 %d 条重复记录", processedCount)
|
||||||
|
return nil
|
||||||
|
}
|
@ -11,6 +11,7 @@ import (
|
|||||||
"aimodels-prices/database"
|
"aimodels-prices/database"
|
||||||
"aimodels-prices/handlers"
|
"aimodels-prices/handlers"
|
||||||
"aimodels-prices/handlers/rates"
|
"aimodels-prices/handlers/rates"
|
||||||
|
initTasks "aimodels-prices/init"
|
||||||
"aimodels-prices/middleware"
|
"aimodels-prices/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,6 +27,9 @@ func main() {
|
|||||||
log.Fatalf("Failed to initialize database: %v", err)
|
log.Fatalf("Failed to initialize database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 运行初始化任务
|
||||||
|
initTasks.RunInitTasks()
|
||||||
|
|
||||||
// 设置gin模式
|
// 设置gin模式
|
||||||
if gin.Mode() == gin.ReleaseMode {
|
if gin.Mode() == gin.ReleaseMode {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user