aimodels-prices/backend/cron/openrouter-api/update-other-price.go
wood chen dce4815654 新增定时任务功能和前端免责声明
- 在 go.mod 中添加 cron 库依赖
- 更新 main.go,初始化并启动定时任务
- 在 Home.vue 中添加免责声明部分,明确价格信息的准确性和更新方式
2025-03-18 01:40:39 +08:00

224 lines
5.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package openrouter_api
import (
"fmt"
"io"
"log"
"net/http"
"encoding/json"
"strings"
"time"
"aimodels-prices/database"
"aimodels-prices/models"
)
// 定义厂商ID映射
var authorToChannelType = map[string]uint{
"openai": 1,
"anthropic": 14,
"qwen": 17,
"google": 25,
"x-ai": 1001,
}
// 定义黑名单列表
var blacklist = []string{
"shap-e",
"palm-2",
"o3-mini-high",
}
const (
OtherPriceSource = "三方API"
OtherStatus = "pending"
)
// UpdateOtherPrices 更新其他厂商的价格
func UpdateOtherPrices() error {
log.Println("开始更新其他厂商价格数据...")
// 复用已有的API请求获取数据
resp, err := fetchOpenRouterData()
if err != nil {
return fmt.Errorf("获取OpenRouter数据失败: %v", err)
}
// 获取数据库连接
db := database.DB
if db == nil {
return fmt.Errorf("获取数据库连接失败")
}
// 处理每个模型的价格数据
processedCount := 0
skippedCount := 0
for _, modelData := range resp.Data {
// 提取模型名称slug中/后面的部分)
parts := strings.Split(modelData.Slug, "/")
if len(parts) < 2 {
log.Printf("跳过无效的模型名称: %s", modelData.Slug)
skippedCount++
continue
}
// 获取模型名称并去除":free"后缀
modelName := parts[1]
modelName = strings.Split(modelName, ":")[0]
// 检查是否在黑名单中
if isInBlacklist(modelName) {
log.Printf("跳过黑名单模型: %s", modelName)
skippedCount++
continue
}
// 获取作者名称
author := parts[0]
// 检查是否支持的厂商
channelType, ok := authorToChannelType[author]
if !ok {
log.Printf("跳过不支持的厂商: %s", author)
skippedCount++
continue
}
// 处理特殊模型名称
if author == "google" {
// 处理gemini-flash-1.5系列模型名称
if strings.HasPrefix(modelName, "gemini-flash-1.5") {
suffix := strings.TrimPrefix(modelName, "gemini-flash-1.5")
modelName = "gemini-1.5-flash" + suffix
log.Printf("修正Google模型名称: %s -> %s", parts[1], modelName)
}
}
// 确定模型类型
modelType := determineModelType(modelData.Modality)
// 解析价格
var inputPrice, outputPrice float64
var parseErr error
// 优先使用endpoint中的pricing
if modelData.Endpoint.Pricing.Prompt != "" {
inputPrice, parseErr = parsePrice(modelData.Endpoint.Pricing.Prompt)
if parseErr != nil {
log.Printf("解析endpoint输入价格失败 %s: %v", modelData.Slug, parseErr)
skippedCount++
continue
}
} else if modelData.Pricing.Prompt != "" {
// 如果endpoint中没有则使用顶层pricing
inputPrice, parseErr = parsePrice(modelData.Pricing.Prompt)
if parseErr != nil {
log.Printf("解析输入价格失败 %s: %v", modelData.Slug, parseErr)
skippedCount++
continue
}
}
if modelData.Endpoint.Pricing.Completion != "" {
outputPrice, parseErr = parsePrice(modelData.Endpoint.Pricing.Completion)
if parseErr != nil {
log.Printf("解析endpoint输出价格失败 %s: %v", modelData.Slug, parseErr)
skippedCount++
continue
}
} else if modelData.Pricing.Completion != "" {
outputPrice, parseErr = parsePrice(modelData.Pricing.Completion)
if parseErr != nil {
log.Printf("解析输出价格失败 %s: %v", modelData.Slug, parseErr)
skippedCount++
continue
}
}
// 检查是否已存在相同模型的价格记录
var existingPrice models.Price
result := db.Where("model = ? AND channel_type = ?", modelName, channelType).First(&existingPrice)
if result.Error == nil {
// 更新现有记录
existingPrice.ModelType = modelType
existingPrice.BillingType = BillingType
existingPrice.Currency = Currency
existingPrice.InputPrice = inputPrice
existingPrice.OutputPrice = outputPrice
existingPrice.PriceSource = OtherPriceSource
existingPrice.Status = OtherStatus
existingPrice.UpdatedAt = time.Now()
if err := db.Save(&existingPrice).Error; err != nil {
log.Printf("更新价格记录失败 %s: %v", modelName, err)
skippedCount++
continue
}
log.Printf("更新价格记录: %s (厂商: %s)", modelName, author)
processedCount++
} else {
// 创建新记录
newPrice := models.Price{
Model: modelName,
ModelType: modelType,
BillingType: BillingType,
ChannelType: channelType,
Currency: Currency,
InputPrice: inputPrice,
OutputPrice: outputPrice,
PriceSource: OtherPriceSource,
Status: OtherStatus,
CreatedBy: CreatedBy,
}
if err := db.Create(&newPrice).Error; err != nil {
log.Printf("创建价格记录失败 %s: %v", modelName, err)
skippedCount++
continue
}
log.Printf("创建新价格记录: %s (厂商: %s)", modelName, author)
processedCount++
}
}
log.Printf("其他厂商价格数据处理完成,成功处理: %d, 跳过: %d", processedCount, skippedCount)
return nil
}
// fetchOpenRouterData 获取OpenRouter API数据
func fetchOpenRouterData() (*OpenRouterResponse, error) {
// 复用已有的HTTP请求逻辑
resp, err := http.Get(OpenRouterAPIURL)
if err != nil {
return nil, fmt.Errorf("请求OpenRouter API失败: %v", err)
}
defer resp.Body.Close()
// 读取响应内容
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应内容失败: %v", err)
}
// 解析JSON数据
var openRouterResp OpenRouterResponse
if err := json.Unmarshal(body, &openRouterResp); err != nil {
return nil, fmt.Errorf("解析JSON数据失败: %v", err)
}
return &openRouterResp, nil
}
// isInBlacklist 检查模型名称是否在黑名单中
func isInBlacklist(modelName string) bool {
modelNameLower := strings.ToLower(modelName)
for _, blacklistItem := range blacklist {
if strings.Contains(modelNameLower, blacklistItem) {
return true
}
}
return false
}