重构价格倍率处理逻辑,提取独立模块

- 将价格倍率处理函数从 handlers/prices.go 移动到新的 handlers/rates 包
- 更新 main.go 中的路由配置,使用新的 rates.GetPriceRates 处理函数
- 在 prices.go 中新增 clearPriceCache 时调用 rates.ClearRatesCache
- 模块化价格倍率计算逻辑,提高代码组织性和可维护性
This commit is contained in:
wood chen 2025-03-12 16:35:15 +08:00
parent e037eaafef
commit 680d684016
3 changed files with 87 additions and 92 deletions

View File

@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"aimodels-prices/database" "aimodels-prices/database"
"aimodels-prices/handlers/rates"
"aimodels-prices/models" "aimodels-prices/models"
) )
@ -355,97 +356,6 @@ func DeletePrice(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Price deleted successfully"}) c.JSON(http.StatusOK, gin.H{"message": "Price deleted successfully"})
} }
// PriceRate 价格倍率结构
type PriceRate struct {
Model string `json:"model"`
ModelType string `json:"model_type"`
Type string `json:"type"`
ChannelType uint `json:"channel_type"`
Input float64 `json:"input"`
Output float64 `json:"output"`
}
// GetPriceRates 获取价格倍率
func GetPriceRates(c *gin.Context) {
cacheKey := "price_rates"
// 尝试从缓存获取
if cachedData, found := database.GlobalCache.Get(cacheKey); found {
if rates, ok := cachedData.([]PriceRate); ok {
c.JSON(http.StatusOK, rates)
return
}
}
// 使用索引优化查询,只查询需要的字段
var prices []models.Price
if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price").
Where("status = 'approved'").
Find(&prices).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
return
}
// 按模型分组 - 使用map优化
modelMap := make(map[string]map[uint]models.Price, len(prices)/2) // 预分配合理大小
for _, price := range prices {
if _, exists := modelMap[price.Model]; !exists {
modelMap[price.Model] = make(map[uint]models.Price, 5) // 假设每个模型有5个提供商
}
modelMap[price.Model][price.ChannelType] = price
}
// 预分配rates切片减少内存分配
rates := make([]PriceRate, 0, len(prices))
// 计算倍率
for model, providers := range modelMap {
// 找出基准价格通常是OpenAI的价格
var basePrice models.Price
var found bool
if baseProvider, exists := providers[1]; exists { // 直接检查ID为1的提供商
basePrice = baseProvider
found = true
}
if !found {
continue
}
// 计算其他厂商相对于基准价格的倍率
for channelType, price := range providers {
if channelType == 1 {
continue // 跳过基准价格
}
// 计算输入和输出的倍率
inputRate := 0.0
if basePrice.InputPrice > 0 {
inputRate = price.InputPrice / basePrice.InputPrice
}
outputRate := 0.0
if basePrice.OutputPrice > 0 {
outputRate = price.OutputPrice / basePrice.OutputPrice
}
rates = append(rates, PriceRate{
Model: model,
ModelType: price.ModelType,
Type: price.BillingType,
ChannelType: channelType,
Input: inputRate,
Output: outputRate,
})
}
}
// 存入缓存有效期10分钟
database.GlobalCache.Set(cacheKey, rates, 10*time.Minute)
c.JSON(http.StatusOK, rates)
}
func ApproveAllPrices(c *gin.Context) { func ApproveAllPrices(c *gin.Context) {
// 查找所有待审核的价格 // 查找所有待审核的价格
var pendingPrices []models.Price var pendingPrices []models.Price
@ -527,4 +437,7 @@ func ApproveAllPrices(c *gin.Context) {
func clearPriceCache() { func clearPriceCache() {
// 由于我们无法精确知道哪些缓存键与价格相关,所以清除所有缓存 // 由于我们无法精确知道哪些缓存键与价格相关,所以清除所有缓存
database.GlobalCache.Clear() database.GlobalCache.Clear()
// 同时清除价格倍率缓存
rates.ClearRatesCache()
} }

View File

@ -0,0 +1,81 @@
package rates
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"aimodels-prices/database"
"aimodels-prices/models"
)
// PriceRate 价格倍率结构
type PriceRate struct {
Model string `json:"model"`
ModelType string `json:"model_type"`
Type string `json:"type"`
ChannelType uint `json:"channel_type"`
Input float64 `json:"input"`
Output float64 `json:"output"`
}
// GetPriceRates 获取价格倍率
func GetPriceRates(c *gin.Context) {
cacheKey := "price_rates"
// 尝试从缓存获取
if cachedData, found := database.GlobalCache.Get(cacheKey); found {
if rates, ok := cachedData.([]PriceRate); ok {
c.JSON(http.StatusOK, rates)
return
}
}
// 使用索引优化查询,只查询需要的字段
var prices []models.Price
if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price, currency, status").
Where("status = 'approved'").
Find(&prices).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
return
}
// 预分配rates切片减少内存分配
rates := make([]PriceRate, 0, len(prices))
// 计算倍率
for _, price := range prices {
// 根据货币类型计算倍率
var inputRate, outputRate float64
if price.Currency == "USD" {
// 如果是美元除以2
inputRate = price.InputPrice / 2
outputRate = price.OutputPrice / 2
} else {
// 如果是人民币或其他货币除以14
inputRate = price.InputPrice / 14
outputRate = price.OutputPrice / 14
}
rates = append(rates, PriceRate{
Model: price.Model,
ModelType: price.ModelType,
Type: price.BillingType,
ChannelType: price.ChannelType,
Input: inputRate,
Output: outputRate,
})
}
// 存入缓存有效期24小时
database.GlobalCache.Set(cacheKey, rates, 24*time.Hour)
c.JSON(http.StatusOK, rates)
}
// ClearRatesCache 清除价格倍率缓存
func ClearRatesCache() {
database.GlobalCache.Delete("price_rates")
}

View File

@ -9,6 +9,7 @@ import (
"aimodels-prices/config" "aimodels-prices/config"
"aimodels-prices/database" "aimodels-prices/database"
"aimodels-prices/handlers" "aimodels-prices/handlers"
"aimodels-prices/handlers/rates"
"aimodels-prices/middleware" "aimodels-prices/middleware"
) )
@ -56,7 +57,7 @@ func main() {
prices := api.Group("/prices") prices := api.Group("/prices")
{ {
prices.GET("", handlers.GetPrices) prices.GET("", handlers.GetPrices)
prices.GET("/rates", handlers.GetPriceRates) prices.GET("/rates", rates.GetPriceRates)
prices.POST("", middleware.AuthRequired(), handlers.CreatePrice) prices.POST("", middleware.AuthRequired(), handlers.CreatePrice)
prices.PUT("/:id", middleware.AuthRequired(), handlers.UpdatePrice) prices.PUT("/:id", middleware.AuthRequired(), handlers.UpdatePrice)
prices.DELETE("/:id", middleware.AuthRequired(), handlers.DeletePrice) prices.DELETE("/:id", middleware.AuthRequired(), handlers.DeletePrice)