diff --git a/backend/cron/main.go b/backend/cron/main.go new file mode 100644 index 0000000..070336c --- /dev/null +++ b/backend/cron/main.go @@ -0,0 +1,73 @@ +package cron + +import ( + "log" + "time" + + "github.com/robfig/cron/v3" + + openrouter_api "aimodels-prices/cron/openrouter-api" +) + +var cronScheduler *cron.Cron + +// InitCronJobs 初始化并启动所有定时任务 +func InitCronJobs() { + log.Println("初始化定时任务...") + + // 创建一个新的cron调度器,使用秒级精度 + cronScheduler = cron.New(cron.WithSeconds()) + + // 注册OpenRouter价格获取任务 + // 每24小时执行一次 + _, err := cronScheduler.AddFunc("0 0 0 * * *", func() { + if err := openrouter_api.FetchAndSavePrices(); err != nil { + log.Printf("OpenRouter价格获取任务执行失败: %v", err) + } + }) + + if err != nil { + log.Printf("注册OpenRouter价格获取任务失败: %v", err) + } + + // 注册其他厂商价格更新任务 + // 每24小时执行一次,错开时间避免同时执行 + _, err = cronScheduler.AddFunc("0 30 0 * * *", func() { + if err := openrouter_api.UpdateOtherPrices(); err != nil { + log.Printf("其他厂商价格更新任务执行失败: %v", err) + } + }) + + if err != nil { + log.Printf("注册其他厂商价格更新任务失败: %v", err) + } + + // 启动定时任务 + cronScheduler.Start() + log.Println("定时任务已启动") + + // 立即执行一次价格获取任务 + go func() { + // 等待几秒钟,确保应用程序和数据库已完全初始化 + time.Sleep(5 * time.Second) + log.Println("立即执行OpenRouter价格获取任务...") + if err := openrouter_api.FetchAndSavePrices(); err != nil { + log.Printf("初始OpenRouter价格获取任务执行失败: %v", err) + } + + // 等待几秒后执行其他厂商价格更新任务 + time.Sleep(3 * time.Second) + log.Println("立即执行其他厂商价格更新任务...") + if err := openrouter_api.UpdateOtherPrices(); err != nil { + log.Printf("初始其他厂商价格更新任务执行失败: %v", err) + } + }() +} + +// StopCronJobs 停止所有定时任务 +func StopCronJobs() { + if cronScheduler != nil { + cronScheduler.Stop() + log.Println("定时任务已停止") + } +} diff --git a/backend/cron/openrouter-api/openrouter-price.go b/backend/cron/openrouter-api/openrouter-price.go new file mode 100644 index 0000000..531f33d --- /dev/null +++ b/backend/cron/openrouter-api/openrouter-price.go @@ -0,0 +1,189 @@ +package openrouter_api + +import ( + "encoding/json" + "fmt" + "io" + "log" + "math" + "net/http" + "strconv" + "time" + + "aimodels-prices/database" + "aimodels-prices/models" +) + +const ( + OpenRouterAPIURL = "https://openrouter.ai/api/frontend/models" + ChannelType = 1002 + BillingType = "tokens" + Currency = "USD" + PriceSource = "https://openrouter.ai/models" + Status = "approved" + CreatedBy = "cron自动任务" +) + +type OpenRouterResponse struct { + Data []ModelData `json:"data"` +} + +type ModelData struct { + Slug string `json:"slug"` + Modality string `json:"modality"` + Pricing Pricing `json:"pricing"` + Endpoint Endpoint `json:"endpoint"` +} + +type Pricing struct { + Prompt string `json:"prompt"` + Completion string `json:"completion"` +} + +type Endpoint struct { + Pricing Pricing `json:"pricing"` +} + +// FetchAndSavePrices 获取OpenRouter API的价格并保存到数据库 +func FetchAndSavePrices() error { + log.Println("开始获取OpenRouter价格数据...") + + // 发送GET请求获取数据 + resp, err := http.Get(OpenRouterAPIURL) + if err != nil { + return fmt.Errorf("请求OpenRouter API失败: %v", err) + } + defer resp.Body.Close() + + // 读取响应内容 + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("读取响应内容失败: %v", err) + } + + // 解析JSON数据 + var openRouterResp OpenRouterResponse + if err := json.Unmarshal(body, &openRouterResp); err != nil { + return fmt.Errorf("解析JSON数据失败: %v", err) + } + + // 获取数据库连接 + db := database.DB + if db == nil { + return fmt.Errorf("获取数据库连接失败") + } + + // 处理每个模型的价格数据 + for _, modelData := range openRouterResp.Data { + // 确定模型类型 + modelType := determineModelType(modelData.Modality) + + // 使用endpoint中的pricing + var inputPrice, outputPrice float64 + var err error + + // 优先使用endpoint中的pricing + if modelData.Endpoint.Pricing.Prompt != "" { + inputPrice, err = parsePrice(modelData.Endpoint.Pricing.Prompt) + if err != nil { + log.Printf("解析endpoint输入价格失败 %s: %v", modelData.Slug, err) + continue + } + } else if modelData.Pricing.Prompt != "" { + // 如果endpoint中没有,则使用顶层pricing + inputPrice, err = parsePrice(modelData.Pricing.Prompt) + if err != nil { + log.Printf("解析输入价格失败 %s: %v", modelData.Slug, err) + continue + } + } + + if modelData.Endpoint.Pricing.Completion != "" { + outputPrice, err = parsePrice(modelData.Endpoint.Pricing.Completion) + if err != nil { + log.Printf("解析endpoint输出价格失败 %s: %v", modelData.Slug, err) + continue + } + } else if modelData.Pricing.Completion != "" { + outputPrice, err = parsePrice(modelData.Pricing.Completion) + if err != nil { + log.Printf("解析输出价格失败 %s: %v", modelData.Slug, err) + continue + } + } + + // 检查是否已存在相同模型的价格记录 + var existingPrice models.Price + result := db.Where("model = ? AND channel_type = ?", modelData.Slug, ChannelType).First(&existingPrice) + + if result.Error == nil { + // 更新现有记录 + existingPrice.ModelType = modelType + existingPrice.BillingType = BillingType + existingPrice.Currency = Currency + existingPrice.InputPrice = inputPrice + existingPrice.OutputPrice = outputPrice + existingPrice.PriceSource = PriceSource + existingPrice.Status = Status + existingPrice.UpdatedAt = time.Now() + + if err := db.Save(&existingPrice).Error; err != nil { + log.Printf("更新价格记录失败 %s: %v", modelData.Slug, err) + continue + } + log.Printf("更新价格记录: %s", modelData.Slug) + } else { + // 创建新记录 + newPrice := models.Price{ + Model: modelData.Slug, + ModelType: modelType, + BillingType: BillingType, + ChannelType: ChannelType, + Currency: Currency, + InputPrice: inputPrice, + OutputPrice: outputPrice, + PriceSource: PriceSource, + Status: Status, + CreatedBy: CreatedBy, + } + + if err := db.Create(&newPrice).Error; err != nil { + log.Printf("创建价格记录失败 %s: %v", modelData.Slug, err) + continue + } + log.Printf("创建新价格记录: %s", modelData.Slug) + } + } + + log.Println("OpenRouter价格数据处理完成") + return nil +} + +// determineModelType 根据modality确定模型类型 +func determineModelType(modality string) string { + switch modality { + case "text->text": + return "text2text" + case "text+image->text": + return "multimodal" + default: + return "other" + } +} + +// parsePrice 解析价格字符串为浮点数并乘以1000000 +func parsePrice(priceStr string) (float64, error) { + if priceStr == "" { + return 0, nil // 如果价格为空,返回0 + } + + price, err := strconv.ParseFloat(priceStr, 64) + if err != nil { + log.Printf("价格解析失败: %s, 错误: %v", priceStr, err) + return 0, err + } + + // 乘以1000000并四舍五入到6位小数,避免浮点数精度问题 + result := math.Round(price*1000000*1000000) / 1000000 + return result, nil +} diff --git a/backend/cron/openrouter-api/update-other-price.go b/backend/cron/openrouter-api/update-other-price.go new file mode 100644 index 0000000..81090af --- /dev/null +++ b/backend/cron/openrouter-api/update-other-price.go @@ -0,0 +1,223 @@ +package openrouter_api + +import ( + "fmt" + "io" + "log" + "net/http" + + "encoding/json" + "strings" + "time" + + "aimodels-prices/database" + "aimodels-prices/models" +) + +// 定义厂商ID映射 +var authorToChannelType = map[string]uint{ + "openai": 1, + "anthropic": 14, + "qwen": 17, + "google": 25, + "x-ai": 1001, +} + +// 定义黑名单列表 +var blacklist = []string{ + "shap-e", + "palm-2", + "o3-mini-high", +} + +const ( + OtherPriceSource = "三方API" + OtherStatus = "pending" +) + +// UpdateOtherPrices 更新其他厂商的价格 +func UpdateOtherPrices() error { + log.Println("开始更新其他厂商价格数据...") + + // 复用已有的API请求获取数据 + resp, err := fetchOpenRouterData() + if err != nil { + return fmt.Errorf("获取OpenRouter数据失败: %v", err) + } + + // 获取数据库连接 + db := database.DB + if db == nil { + return fmt.Errorf("获取数据库连接失败") + } + + // 处理每个模型的价格数据 + processedCount := 0 + skippedCount := 0 + for _, modelData := range resp.Data { + // 提取模型名称(slug中/后面的部分) + parts := strings.Split(modelData.Slug, "/") + if len(parts) < 2 { + log.Printf("跳过无效的模型名称: %s", modelData.Slug) + skippedCount++ + continue + } + + // 获取模型名称并去除":free"后缀 + modelName := parts[1] + modelName = strings.Split(modelName, ":")[0] + + // 检查是否在黑名单中 + if isInBlacklist(modelName) { + log.Printf("跳过黑名单模型: %s", modelName) + skippedCount++ + continue + } + + // 获取作者名称 + author := parts[0] + + // 检查是否支持的厂商 + channelType, ok := authorToChannelType[author] + if !ok { + log.Printf("跳过不支持的厂商: %s", author) + skippedCount++ + continue + } + + // 处理特殊模型名称 + if author == "google" { + // 处理gemini-flash-1.5系列模型名称 + if strings.HasPrefix(modelName, "gemini-flash-1.5") { + suffix := strings.TrimPrefix(modelName, "gemini-flash-1.5") + modelName = "gemini-1.5-flash" + suffix + log.Printf("修正Google模型名称: %s -> %s", parts[1], modelName) + } + } + + // 确定模型类型 + modelType := determineModelType(modelData.Modality) + + // 解析价格 + var inputPrice, outputPrice float64 + var parseErr error + + // 优先使用endpoint中的pricing + if modelData.Endpoint.Pricing.Prompt != "" { + inputPrice, parseErr = parsePrice(modelData.Endpoint.Pricing.Prompt) + if parseErr != nil { + log.Printf("解析endpoint输入价格失败 %s: %v", modelData.Slug, parseErr) + skippedCount++ + continue + } + } else if modelData.Pricing.Prompt != "" { + // 如果endpoint中没有,则使用顶层pricing + inputPrice, parseErr = parsePrice(modelData.Pricing.Prompt) + if parseErr != nil { + log.Printf("解析输入价格失败 %s: %v", modelData.Slug, parseErr) + skippedCount++ + continue + } + } + + if modelData.Endpoint.Pricing.Completion != "" { + outputPrice, parseErr = parsePrice(modelData.Endpoint.Pricing.Completion) + if parseErr != nil { + log.Printf("解析endpoint输出价格失败 %s: %v", modelData.Slug, parseErr) + skippedCount++ + continue + } + } else if modelData.Pricing.Completion != "" { + outputPrice, parseErr = parsePrice(modelData.Pricing.Completion) + if parseErr != nil { + log.Printf("解析输出价格失败 %s: %v", modelData.Slug, parseErr) + skippedCount++ + continue + } + } + + // 检查是否已存在相同模型的价格记录 + var existingPrice models.Price + result := db.Where("model = ? AND channel_type = ?", modelName, channelType).First(&existingPrice) + + if result.Error == nil { + // 更新现有记录 + existingPrice.ModelType = modelType + existingPrice.BillingType = BillingType + existingPrice.Currency = Currency + existingPrice.InputPrice = inputPrice + existingPrice.OutputPrice = outputPrice + existingPrice.PriceSource = OtherPriceSource + existingPrice.Status = OtherStatus + existingPrice.UpdatedAt = time.Now() + + if err := db.Save(&existingPrice).Error; err != nil { + log.Printf("更新价格记录失败 %s: %v", modelName, err) + skippedCount++ + continue + } + log.Printf("更新价格记录: %s (厂商: %s)", modelName, author) + processedCount++ + } else { + // 创建新记录 + newPrice := models.Price{ + Model: modelName, + ModelType: modelType, + BillingType: BillingType, + ChannelType: channelType, + Currency: Currency, + InputPrice: inputPrice, + OutputPrice: outputPrice, + PriceSource: OtherPriceSource, + Status: OtherStatus, + CreatedBy: CreatedBy, + } + + if err := db.Create(&newPrice).Error; err != nil { + log.Printf("创建价格记录失败 %s: %v", modelName, err) + skippedCount++ + continue + } + log.Printf("创建新价格记录: %s (厂商: %s)", modelName, author) + processedCount++ + } + } + + log.Printf("其他厂商价格数据处理完成,成功处理: %d, 跳过: %d", processedCount, skippedCount) + return nil +} + +// fetchOpenRouterData 获取OpenRouter API数据 +func fetchOpenRouterData() (*OpenRouterResponse, error) { + // 复用已有的HTTP请求逻辑 + resp, err := http.Get(OpenRouterAPIURL) + if err != nil { + return nil, fmt.Errorf("请求OpenRouter API失败: %v", err) + } + defer resp.Body.Close() + + // 读取响应内容 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应内容失败: %v", err) + } + + // 解析JSON数据 + var openRouterResp OpenRouterResponse + if err := json.Unmarshal(body, &openRouterResp); err != nil { + return nil, fmt.Errorf("解析JSON数据失败: %v", err) + } + + return &openRouterResp, nil +} + +// isInBlacklist 检查模型名称是否在黑名单中 +func isInBlacklist(modelName string) bool { + modelNameLower := strings.ToLower(modelName) + for _, blacklistItem := range blacklist { + if strings.Contains(modelNameLower, blacklistItem) { + return true + } + } + return false +} diff --git a/backend/go.mod b/backend/go.mod index 6be0d07..0675ceb 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -7,6 +7,7 @@ toolchain go1.23.1 require ( github.com/gin-gonic/gin v1.9.1 github.com/joho/godotenv v1.5.1 + github.com/robfig/cron/v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/gorm v1.25.12 ) diff --git a/backend/go.sum b/backend/go.sum index a292d23..6b17547 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -56,6 +56,8 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/backend/main.go b/backend/main.go index 38e5b8a..3763a2b 100644 --- a/backend/main.go +++ b/backend/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "aimodels-prices/config" + "aimodels-prices/cron" "aimodels-prices/database" "aimodels-prices/handlers" "aimodels-prices/handlers/rates" @@ -30,6 +31,10 @@ func main() { gin.SetMode(gin.ReleaseMode) } + // 初始化并启动定时任务 + cron.InitCronJobs() + defer cron.StopCronJobs() + r := gin.Default() // CORS中间件 diff --git a/frontend/src/views/Home.vue b/frontend/src/views/Home.vue index a11958a..142c9dc 100644 --- a/frontend/src/views/Home.vue +++ b/frontend/src/views/Home.vue @@ -20,8 +20,13 @@

交流讨论

-

请在帖子下留言: https://q58.club/t/topic/277

+

请在帖子下留言: https://www.q58.club/t/topic/277

+

免责声明

+

+ 所有价格信息仅供参考,不保证100%准确性,及时性, 完整性, 价格由人工编辑+API自动更新, 所以可能有误差, 具体价格以实际为准。 +

+

API文档