mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 05:32:00 +08:00
- 在 UpdateOtherPrices 函数中新增已处理模型的跟踪机制,避免重复处理同一模型 - 增强模型数据处理逻辑,确保在创建新价格记录前检查待审核记录 - 更新 main.go,添加初始化任务的调用,提升系统启动时的功能完整性
315 lines
8.6 KiB
Go
315 lines
8.6 KiB
Go
package openrouter_api
|
||
|
||
import (
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
|
||
"encoding/json"
|
||
"strings"
|
||
|
||
"aimodels-prices/database"
|
||
"aimodels-prices/handlers"
|
||
"aimodels-prices/models"
|
||
)
|
||
|
||
// 定义厂商ID映射
|
||
var authorToChannelType = map[string]uint{
|
||
"openai": 1,
|
||
"anthropic": 14,
|
||
"google": 25,
|
||
"x-ai": 1001,
|
||
}
|
||
|
||
// 定义黑名单列表
|
||
var blacklist = []string{
|
||
"shap-e",
|
||
"palm-2",
|
||
"o3-mini-high",
|
||
"claude-instant",
|
||
"claude-1",
|
||
"claude-3-haiku",
|
||
"claude-3-opus",
|
||
"claude-3-sonnet",
|
||
":extended",
|
||
}
|
||
|
||
const (
|
||
OtherPriceSource = "三方API"
|
||
OtherStatus = "pending"
|
||
)
|
||
|
||
// UpdateOtherPrices 更新其他厂商的价格
|
||
func UpdateOtherPrices() error {
|
||
log.Println("开始更新其他厂商价格数据...")
|
||
|
||
// 复用已有的API请求获取数据
|
||
resp, err := fetchOpenRouterData()
|
||
if err != nil {
|
||
return fmt.Errorf("获取OpenRouter数据失败: %v", err)
|
||
}
|
||
|
||
// 获取数据库连接
|
||
db := database.DB
|
||
if db == nil {
|
||
return fmt.Errorf("获取数据库连接失败")
|
||
}
|
||
|
||
// 处理每个模型的价格数据
|
||
processedCount := 0
|
||
skippedCount := 0
|
||
|
||
// 创建一个映射,用于按作者和模型名称存储模型数据
|
||
// 键:作者/模型名称基础部分
|
||
// 值:带有free标识和不带free标识的模型数据
|
||
modelDataMap := make(map[string]map[bool]*ModelData)
|
||
|
||
// 第一遍遍历,分类整理模型数据
|
||
for _, modelData := range resp.Data {
|
||
// 提取模型名称(slug中/后面的部分)
|
||
parts := strings.Split(modelData.Slug, "/")
|
||
if len(parts) < 2 {
|
||
log.Printf("跳过无效的模型名称: %s", modelData.Slug)
|
||
continue
|
||
}
|
||
|
||
author := parts[0]
|
||
fullModelName := parts[1]
|
||
|
||
// 判断是否带有":free"后缀
|
||
isFree := strings.HasSuffix(fullModelName, ":free")
|
||
|
||
// 提取基础模型名称(不带":free"后缀)
|
||
baseModelName := fullModelName
|
||
if isFree {
|
||
baseModelName = strings.TrimSuffix(fullModelName, ":free")
|
||
}
|
||
|
||
// 创建模型的唯一键
|
||
modelKey := author + "/" + baseModelName
|
||
|
||
// 如果需要,为这个模型键初始化一个条目
|
||
if _, exists := modelDataMap[modelKey]; !exists {
|
||
modelDataMap[modelKey] = make(map[bool]*ModelData)
|
||
}
|
||
|
||
// 存储模型数据
|
||
modelDataMap[modelKey][isFree] = &modelData
|
||
}
|
||
|
||
// 创建一个集合用于跟踪已处理的模型,避免重复
|
||
processedModels := make(map[string]bool)
|
||
|
||
// 第二遍遍历,根据处理规则选择合适的模型数据
|
||
for modelKey, variants := range modelDataMap {
|
||
var modelData *ModelData
|
||
|
||
// 优先选择非free版本
|
||
if nonFreeData, hasNonFree := variants[false]; hasNonFree {
|
||
modelData = nonFreeData
|
||
} else if freeData, hasFree := variants[true]; hasFree {
|
||
// 如果只有free版本,则使用free版本
|
||
modelData = freeData
|
||
} else {
|
||
// 不应该发生,但为了安全
|
||
log.Printf("处理模型数据异常: %s", modelKey)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 提取模型名称
|
||
parts := strings.Split(modelData.Slug, "/")
|
||
modelName := strings.Split(parts[1], ":")[0] // 移除":free"后缀
|
||
author := parts[0]
|
||
|
||
// 检查是否在黑名单中
|
||
if isInBlacklist(modelName) {
|
||
log.Printf("跳过黑名单模型: %s", modelName)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 检查是否支持的厂商
|
||
channelType, ok := authorToChannelType[author]
|
||
if !ok {
|
||
log.Printf("跳过不支持的厂商: %s", author)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 处理特殊模型名称
|
||
if author == "google" {
|
||
// 处理gemini-flash-1.5系列模型名称
|
||
if strings.HasPrefix(modelName, "gemini-flash-1.5") {
|
||
suffix := strings.TrimPrefix(modelName, "gemini-flash-1.5")
|
||
modelName = "gemini-1.5-flash" + suffix
|
||
log.Printf("修正Google模型名称: %s -> %s", parts[1], modelName)
|
||
}
|
||
}
|
||
if author == "anthropic" {
|
||
// 处理claude-3.5-sonnet系列模型名称
|
||
if strings.HasPrefix(modelName, "claude-3.5") {
|
||
suffix := strings.TrimPrefix(modelName, "claude-3.5")
|
||
modelName = "claude-3-5" + suffix
|
||
log.Printf("修正Claude模型名称: %s -> %s", parts[1], modelName)
|
||
}
|
||
|
||
if strings.HasPrefix(modelName, "claude-3.7") {
|
||
suffix := strings.TrimPrefix(modelName, "claude-3.7")
|
||
modelName = "claude-3-7" + suffix
|
||
log.Printf("修正Claude模型名称: %s -> %s", parts[1], modelName)
|
||
}
|
||
}
|
||
|
||
// 创建唯一标识符,用于避免同厂商同模型重复处理
|
||
uniqueModelKey := fmt.Sprintf("%d:%s", channelType, modelName)
|
||
|
||
// 检查是否已处理过这个模型
|
||
if processedModels[uniqueModelKey] {
|
||
log.Printf("跳过已处理的模型: %s (厂商: %s)", modelName, author)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 标记此模型为已处理
|
||
processedModels[uniqueModelKey] = true
|
||
|
||
// 确定模型类型
|
||
modelType := determineModelType(modelData.Modality)
|
||
|
||
// 解析价格
|
||
var inputPrice, outputPrice float64
|
||
var parseErr error
|
||
|
||
// 如果输入或输出价格为空,直接跳过
|
||
if modelData.Endpoint.Pricing.Prompt == "" || modelData.Endpoint.Pricing.Completion == "" {
|
||
log.Printf("跳过价格数据不完整的模型: %s", modelData.Slug)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 使用endpoint中的pricing
|
||
if modelData.Endpoint.Pricing.Prompt != "" {
|
||
inputPrice, parseErr = parsePrice(modelData.Endpoint.Pricing.Prompt)
|
||
if parseErr != nil {
|
||
log.Printf("解析endpoint输入价格失败 %s: %v", modelData.Slug, parseErr)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
}
|
||
|
||
if modelData.Endpoint.Pricing.Completion != "" {
|
||
outputPrice, parseErr = parsePrice(modelData.Endpoint.Pricing.Completion)
|
||
if parseErr != nil {
|
||
log.Printf("解析endpoint输出价格失败 %s: %v", modelData.Slug, parseErr)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 创建价格对象
|
||
price := models.Price{
|
||
Model: modelName,
|
||
ModelType: modelType,
|
||
BillingType: BillingType,
|
||
ChannelType: channelType,
|
||
Currency: Currency,
|
||
InputPrice: inputPrice,
|
||
OutputPrice: outputPrice,
|
||
PriceSource: OtherPriceSource,
|
||
Status: OtherStatus,
|
||
CreatedBy: CreatedBy,
|
||
}
|
||
|
||
// 检查是否已存在相同模型的价格记录
|
||
var existingPrice models.Price
|
||
result := db.Where("model = ? AND channel_type = ?", modelName, channelType).First(&existingPrice)
|
||
|
||
if result.Error == nil {
|
||
// 使用processPrice函数处理更新
|
||
_, changed, err := handlers.ProcessPrice(price, &existingPrice, false, CreatedBy)
|
||
if err != nil {
|
||
log.Printf("更新价格记录失败 %s: %v", modelName, err)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
if changed {
|
||
log.Printf("更新价格记录: %s (厂商: %s)", modelName, author)
|
||
processedCount++
|
||
} else {
|
||
log.Printf("价格无变化,跳过更新: %s (厂商: %s)", modelName, author)
|
||
skippedCount++
|
||
}
|
||
} else {
|
||
// 检查是否存在相同模型名称的待审核记录
|
||
var pendingCount int64
|
||
if err := db.Model(&models.Price{}).Where("model = ? AND channel_type = ? AND status = 'pending'",
|
||
modelName, channelType).Count(&pendingCount).Error; err != nil {
|
||
log.Printf("检查待审核记录失败 %s: %v", modelName, err)
|
||
}
|
||
|
||
if pendingCount > 0 {
|
||
log.Printf("已存在待审核的相同模型记录,跳过创建: %s (厂商: %s)", modelName, author)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 使用processPrice函数处理创建
|
||
_, changed, err := handlers.ProcessPrice(price, nil, false, CreatedBy)
|
||
if err != nil {
|
||
log.Printf("创建价格记录失败 %s: %v", modelName, err)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
if changed {
|
||
log.Printf("创建新价格记录: %s (厂商: %s)", modelName, author)
|
||
processedCount++
|
||
} else {
|
||
log.Printf("价格创建失败: %s (厂商: %s)", modelName, author)
|
||
skippedCount++
|
||
}
|
||
}
|
||
}
|
||
|
||
log.Printf("其他厂商价格数据处理完成,成功处理: %d, 跳过: %d", processedCount, skippedCount)
|
||
return nil
|
||
}
|
||
|
||
// fetchOpenRouterData 获取OpenRouter API数据
|
||
func fetchOpenRouterData() (*OpenRouterResponse, error) {
|
||
// 复用已有的HTTP请求逻辑
|
||
resp, err := http.Get(OpenRouterAPIURL)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求OpenRouter API失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应内容
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取响应内容失败: %v", err)
|
||
}
|
||
|
||
// 解析JSON数据
|
||
var openRouterResp OpenRouterResponse
|
||
if err := json.Unmarshal(body, &openRouterResp); err != nil {
|
||
return nil, fmt.Errorf("解析JSON数据失败: %v", err)
|
||
}
|
||
|
||
return &openRouterResp, nil
|
||
}
|
||
|
||
// isInBlacklist 检查模型名称是否在黑名单中
|
||
func isInBlacklist(modelName string) bool {
|
||
modelNameLower := strings.ToLower(modelName)
|
||
for _, blacklistItem := range blacklist {
|
||
if strings.Contains(modelNameLower, blacklistItem) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|