diff --git a/backend/handlers/prices.go b/backend/handlers/prices.go index bab70e4..6b49aae 100644 --- a/backend/handlers/prices.go +++ b/backend/handlers/prices.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "aimodels-prices/database" + "aimodels-prices/handlers/rates" "aimodels-prices/models" ) @@ -355,97 +356,6 @@ func DeletePrice(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "Price deleted successfully"}) } -// PriceRate 价格倍率结构 -type PriceRate struct { - Model string `json:"model"` - ModelType string `json:"model_type"` - Type string `json:"type"` - ChannelType uint `json:"channel_type"` - Input float64 `json:"input"` - Output float64 `json:"output"` -} - -// GetPriceRates 获取价格倍率 -func GetPriceRates(c *gin.Context) { - cacheKey := "price_rates" - - // 尝试从缓存获取 - if cachedData, found := database.GlobalCache.Get(cacheKey); found { - if rates, ok := cachedData.([]PriceRate); ok { - c.JSON(http.StatusOK, rates) - return - } - } - - // 使用索引优化查询,只查询需要的字段 - var prices []models.Price - if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price"). - Where("status = 'approved'"). - Find(&prices).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"}) - return - } - - // 按模型分组 - 使用map优化 - modelMap := make(map[string]map[uint]models.Price, len(prices)/2) // 预分配合理大小 - for _, price := range prices { - if _, exists := modelMap[price.Model]; !exists { - modelMap[price.Model] = make(map[uint]models.Price, 5) // 假设每个模型有5个提供商 - } - modelMap[price.Model][price.ChannelType] = price - } - - // 预分配rates切片,减少内存分配 - rates := make([]PriceRate, 0, len(prices)) - - // 计算倍率 - for model, providers := range modelMap { - // 找出基准价格(通常是OpenAI的价格) - var basePrice models.Price - var found bool - if baseProvider, exists := providers[1]; exists { // 直接检查ID为1的提供商 - basePrice = baseProvider - found = true - } - - if !found { - continue - } - - // 计算其他厂商相对于基准价格的倍率 - for channelType, price := range providers { - if channelType == 1 { - continue // 跳过基准价格 - } - - // 计算输入和输出的倍率 - inputRate := 0.0 - if basePrice.InputPrice > 0 { - inputRate = price.InputPrice / basePrice.InputPrice - } - - outputRate := 0.0 - if basePrice.OutputPrice > 0 { - outputRate = price.OutputPrice / basePrice.OutputPrice - } - - rates = append(rates, PriceRate{ - Model: model, - ModelType: price.ModelType, - Type: price.BillingType, - ChannelType: channelType, - Input: inputRate, - Output: outputRate, - }) - } - } - - // 存入缓存,有效期10分钟 - database.GlobalCache.Set(cacheKey, rates, 10*time.Minute) - - c.JSON(http.StatusOK, rates) -} - func ApproveAllPrices(c *gin.Context) { // 查找所有待审核的价格 var pendingPrices []models.Price @@ -527,4 +437,7 @@ func ApproveAllPrices(c *gin.Context) { func clearPriceCache() { // 由于我们无法精确知道哪些缓存键与价格相关,所以清除所有缓存 database.GlobalCache.Clear() + + // 同时清除价格倍率缓存 + rates.ClearRatesCache() } diff --git a/backend/handlers/rates/price_rates.go b/backend/handlers/rates/price_rates.go new file mode 100644 index 0000000..f45a344 --- /dev/null +++ b/backend/handlers/rates/price_rates.go @@ -0,0 +1,81 @@ +package rates + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + + "aimodels-prices/database" + "aimodels-prices/models" +) + +// PriceRate 价格倍率结构 +type PriceRate struct { + Model string `json:"model"` + ModelType string `json:"model_type"` + Type string `json:"type"` + ChannelType uint `json:"channel_type"` + Input float64 `json:"input"` + Output float64 `json:"output"` +} + +// GetPriceRates 获取价格倍率 +func GetPriceRates(c *gin.Context) { + cacheKey := "price_rates" + + // 尝试从缓存获取 + if cachedData, found := database.GlobalCache.Get(cacheKey); found { + if rates, ok := cachedData.([]PriceRate); ok { + c.JSON(http.StatusOK, rates) + return + } + } + + // 使用索引优化查询,只查询需要的字段 + var prices []models.Price + if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price, currency, status"). + Where("status = 'approved'"). + Find(&prices).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"}) + return + } + + // 预分配rates切片,减少内存分配 + rates := make([]PriceRate, 0, len(prices)) + + // 计算倍率 + for _, price := range prices { + // 根据货币类型计算倍率 + var inputRate, outputRate float64 + + if price.Currency == "USD" { + // 如果是美元,除以2 + inputRate = price.InputPrice / 2 + outputRate = price.OutputPrice / 2 + } else { + // 如果是人民币或其他货币,除以14 + inputRate = price.InputPrice / 14 + outputRate = price.OutputPrice / 14 + } + + rates = append(rates, PriceRate{ + Model: price.Model, + ModelType: price.ModelType, + Type: price.BillingType, + ChannelType: price.ChannelType, + Input: inputRate, + Output: outputRate, + }) + } + + // 存入缓存,有效期24小时 + database.GlobalCache.Set(cacheKey, rates, 24*time.Hour) + + c.JSON(http.StatusOK, rates) +} + +// ClearRatesCache 清除价格倍率缓存 +func ClearRatesCache() { + database.GlobalCache.Delete("price_rates") +} diff --git a/backend/main.go b/backend/main.go index f4f402e..38e5b8a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -9,6 +9,7 @@ import ( "aimodels-prices/config" "aimodels-prices/database" "aimodels-prices/handlers" + "aimodels-prices/handlers/rates" "aimodels-prices/middleware" ) @@ -56,7 +57,7 @@ func main() { prices := api.Group("/prices") { prices.GET("", handlers.GetPrices) - prices.GET("/rates", handlers.GetPriceRates) + prices.GET("/rates", rates.GetPriceRates) prices.POST("", middleware.AuthRequired(), handlers.CreatePrice) prices.PUT("/:id", middleware.AuthRequired(), handlers.UpdatePrice) prices.DELETE("/:id", middleware.AuthRequired(), handlers.DeletePrice)