mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 21:51:59 +08:00
335 lines
9.3 KiB
Go
335 lines
9.3 KiB
Go
package siliconflow_api
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"math"
|
||
"net/http"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"aimodels-prices/database"
|
||
"aimodels-prices/handlers"
|
||
"aimodels-prices/handlers/one_hub"
|
||
"aimodels-prices/models"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// 常量定义
|
||
const (
|
||
SiliconFlowChannelType = 45 // SiliconFlow的厂商ID
|
||
SiliconFlowAPIEndpoint = "/api/v1/playground/comprehensive/all"
|
||
SiliconFlowAPIHost = "busy-bear.siliconflow.cn"
|
||
PriceSource = "SiliconFlow API"
|
||
Status = "approved" // 设置为approved状态
|
||
CreatedBy = "cron自动任务"
|
||
Currency = "CNY" // 使用人民币
|
||
)
|
||
|
||
// 计费类型常量
|
||
const (
|
||
BillingTypeTokens = "tokens" // 基于token的计费方式
|
||
BillingTypeTimes = "times" // 基于次数的计费方式
|
||
)
|
||
|
||
// 定义API响应结构
|
||
type SiliconFlowResponse struct {
|
||
Code int `json:"code"`
|
||
Message string `json:"message"`
|
||
Status bool `json:"status"`
|
||
Data struct {
|
||
Models []SiliconFlowModel `json:"models"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
// 模型信息结构
|
||
type SiliconFlowModel struct {
|
||
ModelId string `json:"modelId"`
|
||
ModelName string `json:"modelName"`
|
||
DisplayName string `json:"DisplayName"`
|
||
Mf string `json:"mf"`
|
||
Desc string `json:"desc"`
|
||
Tags []string `json:"tags"`
|
||
Icon string `json:"icon"`
|
||
Size int `json:"size"`
|
||
ContextLen int `json:"contextLen"`
|
||
Price string `json:"price"`
|
||
Currency string `json:"currency"`
|
||
PriceUnit string `json:"priceUnit"`
|
||
Status string `json:"status"`
|
||
Type string `json:"type"`
|
||
SubType string `json:"subType"`
|
||
JsonModeSupport bool `json:"jsonModeSupport"`
|
||
FunctionCallSupport bool `json:"functionCallSupport"`
|
||
}
|
||
|
||
// UpdateSiliconFlowPrices 更新SiliconFlow模型价格
|
||
func UpdateSiliconFlowPrices() error {
|
||
log.Println("开始更新SiliconFlow价格数据...")
|
||
|
||
// 获取API数据
|
||
modelData, err := fetchSiliconFlowData()
|
||
if err != nil {
|
||
return fmt.Errorf("获取SiliconFlow数据失败: %v", err)
|
||
}
|
||
|
||
// 获取数据库连接
|
||
db := database.DB
|
||
if db == nil {
|
||
return fmt.Errorf("获取数据库连接失败")
|
||
}
|
||
|
||
// 处理每个模型的价格数据
|
||
processedCount := 0
|
||
skippedCount := 0
|
||
|
||
// 创建一个集合用于跟踪已处理的模型,避免重复
|
||
processedModels := make(map[string]bool)
|
||
|
||
for _, model := range modelData {
|
||
modelName := model.ModelName
|
||
|
||
// 检查是否已处理过这个模型
|
||
if processedModels[modelName] {
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 标记此模型为已处理
|
||
processedModels[modelName] = true
|
||
|
||
// 解析价格
|
||
modelPrice, err := strconv.ParseFloat(model.Price, 64)
|
||
if err != nil {
|
||
log.Printf("解析价格失败 %s: %v", modelName, err)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 确定模型类型和价格
|
||
var modelType string
|
||
var billingType string
|
||
var inputPrice, outputPrice float64
|
||
|
||
// 根据模型类型和价格单位确定模型类型和价格计算方式
|
||
switch {
|
||
case isTokenBasedUnit(model.PriceUnit):
|
||
// 基于Token的模型(如文本模型)
|
||
modelType = determineModelTypeBySubType(model.Type, model.SubType)
|
||
billingType = BillingTypeTokens // 使用tokens计费类型
|
||
// 直接使用价格,系统已经按每百万token为单位
|
||
inputPrice = roundPrice(modelPrice)
|
||
outputPrice = inputPrice // 使用相同价格
|
||
case isTimeBasedUnit(model.PriceUnit, model.Type):
|
||
// 基于次数的模型(如图像、视频)
|
||
modelType = determineModelTypeBySubType(model.Type, model.SubType)
|
||
billingType = BillingTypeTimes // 使用times计费类型
|
||
// 直接使用价格
|
||
inputPrice = roundPrice(modelPrice)
|
||
outputPrice = inputPrice // 使用相同价格
|
||
default:
|
||
// 默认按token计费
|
||
modelType = determineModelTypeBySubType(model.Type, model.SubType)
|
||
// 根据模型类型决定计费方式
|
||
if modelType == "text2image" || modelType == "text2video" || modelType == "image2video" {
|
||
billingType = BillingTypeTimes // 图像和视频相关模型使用times
|
||
} else {
|
||
billingType = BillingTypeTokens // 其他默认使用tokens
|
||
}
|
||
// 对于未知类型,默认按token处理
|
||
inputPrice = roundPrice(modelPrice)
|
||
outputPrice = inputPrice // 使用相同价格
|
||
log.Printf("未识别的价格单位: %s,默认使用计费类型: %s", model.PriceUnit, billingType)
|
||
}
|
||
|
||
// 创建价格对象
|
||
price := models.Price{
|
||
Model: modelName,
|
||
ModelType: modelType,
|
||
BillingType: billingType, // 使用动态确定的计费类型
|
||
ChannelType: SiliconFlowChannelType,
|
||
Currency: Currency, // 使用人民币
|
||
InputPrice: inputPrice,
|
||
OutputPrice: outputPrice,
|
||
PriceSource: PriceSource,
|
||
Status: Status, // 使用approved状态
|
||
CreatedBy: CreatedBy,
|
||
}
|
||
|
||
// 检查是否已存在相同模型的价格记录
|
||
var existingPrice models.Price
|
||
// 使用静默查询,不输出"record not found"错误
|
||
result := db.Where("model = ? AND channel_type = ?", modelName, SiliconFlowChannelType).First(&existingPrice)
|
||
|
||
if result.Error == nil {
|
||
// 记录存在,执行更新
|
||
_, changed, err := handlers.ProcessPrice(price, &existingPrice, true, CreatedBy)
|
||
if err != nil {
|
||
log.Printf("更新价格记录失败 %s: %v", modelName, err)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
if changed {
|
||
processedCount++
|
||
} else {
|
||
skippedCount++
|
||
}
|
||
} else if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||
// 记录不存在,需要创建新记录
|
||
// 检查是否存在相同模型名称的待审核记录
|
||
var pendingCount int64
|
||
if err := db.Model(&models.Price{}).Where("model = ? AND channel_type = ? AND status = 'pending'",
|
||
modelName, SiliconFlowChannelType).Count(&pendingCount).Error; err != nil {
|
||
log.Printf("检查待审核记录失败 %s: %v", modelName, err)
|
||
}
|
||
|
||
if pendingCount > 0 {
|
||
log.Printf("已存在待审核的相同模型记录,跳过创建: %s", modelName)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
// 创建新记录
|
||
_, changed, err := handlers.ProcessPrice(price, nil, true, CreatedBy)
|
||
if err != nil {
|
||
log.Printf("创建价格记录失败 %s: %v", modelName, err)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
|
||
if changed {
|
||
processedCount++
|
||
} else {
|
||
log.Printf("价格创建失败: %s", modelName)
|
||
skippedCount++
|
||
}
|
||
} else {
|
||
// 其他错误
|
||
log.Printf("查询价格记录时发生错误 %s: %v", modelName, result.Error)
|
||
skippedCount++
|
||
continue
|
||
}
|
||
}
|
||
|
||
log.Printf("SiliconFlow价格数据处理完成,成功处理: %d, 跳过: %d", processedCount, skippedCount)
|
||
|
||
// 清除倍率缓存
|
||
one_hub.ClearRatesCache()
|
||
log.Println("倍率缓存已清除")
|
||
return nil
|
||
}
|
||
|
||
// roundPrice 对价格进行四舍五入处理,保留6位小数
|
||
func roundPrice(price float64) float64 {
|
||
// 保留6位小数
|
||
return math.Round(price*1000000) / 1000000
|
||
}
|
||
|
||
// fetchSiliconFlowData 获取SiliconFlow API数据
|
||
func fetchSiliconFlowData() ([]SiliconFlowModel, error) {
|
||
apiKey := os.Getenv("SILICONFLOW_API_KEY")
|
||
if apiKey == "" {
|
||
return nil, fmt.Errorf("环境变量SILICONFLOW_API_KEY未设置")
|
||
}
|
||
|
||
// 创建HTTPS连接
|
||
conn, err := http.NewRequest("GET", "https://"+SiliconFlowAPIHost+SiliconFlowAPIEndpoint, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建HTTP请求失败: %v", err)
|
||
}
|
||
|
||
// 设置请求头
|
||
conn.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||
|
||
// 发送请求
|
||
client := &http.Client{}
|
||
resp, err := client.Do(conn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求SiliconFlow API失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应内容
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取响应内容失败: %v", err)
|
||
}
|
||
|
||
// 解析JSON数据
|
||
var siliconFlowResp SiliconFlowResponse
|
||
if err := json.Unmarshal(body, &siliconFlowResp); err != nil {
|
||
return nil, fmt.Errorf("解析JSON数据失败: %v", err)
|
||
}
|
||
|
||
// 检查响应状态
|
||
if !siliconFlowResp.Status || siliconFlowResp.Code != 20000 {
|
||
return nil, fmt.Errorf("API请求返回错误: %s", siliconFlowResp.Message)
|
||
}
|
||
|
||
return siliconFlowResp.Data.Models, nil
|
||
}
|
||
|
||
// isTokenBasedUnit 判断是否是基于token的计费单位
|
||
func isTokenBasedUnit(unit string) bool {
|
||
tokenUnits := []string{
|
||
"/ M Tokens",
|
||
"/ M UTF-8 bytes",
|
||
"/ M px / Steps",
|
||
}
|
||
|
||
for _, tokenUnit := range tokenUnits {
|
||
if strings.Contains(unit, tokenUnit) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// isTimeBasedUnit 判断是否是基于次数的计费单位
|
||
func isTimeBasedUnit(unit string, modelType string) bool {
|
||
timeUnits := []string{
|
||
"/ Video",
|
||
"/ Image",
|
||
"",
|
||
}
|
||
|
||
// 如果模型类型是视频或图像,即使价格单位为空也按次数计费
|
||
if modelType == "video" || modelType == "image" {
|
||
return true
|
||
}
|
||
|
||
for _, timeUnit := range timeUnits {
|
||
if strings.Contains(unit, timeUnit) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// determineModelTypeBySubType 根据模型类型和子类型确定我们系统中的模型类型
|
||
func determineModelTypeBySubType(modelType string, subType string) string {
|
||
switch modelType {
|
||
case "text":
|
||
return "text2text"
|
||
case "image":
|
||
return "text2image"
|
||
case "video":
|
||
if subType == "image-to-video" {
|
||
return "image2video"
|
||
}
|
||
return "text2video"
|
||
case "audio":
|
||
return "text2speech"
|
||
case "embedding":
|
||
return "embedding"
|
||
default:
|
||
return "other"
|
||
}
|
||
}
|