mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 05:32:00 +08:00
重构价格倍率处理逻辑,提取独立模块
- 将价格倍率处理函数从 handlers/prices.go 移动到新的 handlers/rates 包 - 更新 main.go 中的路由配置,使用新的 rates.GetPriceRates 处理函数 - 在 prices.go 中新增 clearPriceCache 时调用 rates.ClearRatesCache - 模块化价格倍率计算逻辑,提高代码组织性和可维护性
This commit is contained in:
parent
e037eaafef
commit
680d684016
@ -9,6 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"aimodels-prices/database"
|
||||
"aimodels-prices/handlers/rates"
|
||||
"aimodels-prices/models"
|
||||
)
|
||||
|
||||
@ -355,97 +356,6 @@ func DeletePrice(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Price deleted successfully"})
|
||||
}
|
||||
|
||||
// PriceRate 价格倍率结构
|
||||
type PriceRate struct {
|
||||
Model string `json:"model"`
|
||||
ModelType string `json:"model_type"`
|
||||
Type string `json:"type"`
|
||||
ChannelType uint `json:"channel_type"`
|
||||
Input float64 `json:"input"`
|
||||
Output float64 `json:"output"`
|
||||
}
|
||||
|
||||
// GetPriceRates 获取价格倍率
|
||||
func GetPriceRates(c *gin.Context) {
|
||||
cacheKey := "price_rates"
|
||||
|
||||
// 尝试从缓存获取
|
||||
if cachedData, found := database.GlobalCache.Get(cacheKey); found {
|
||||
if rates, ok := cachedData.([]PriceRate); ok {
|
||||
c.JSON(http.StatusOK, rates)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用索引优化查询,只查询需要的字段
|
||||
var prices []models.Price
|
||||
if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price").
|
||||
Where("status = 'approved'").
|
||||
Find(&prices).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
|
||||
return
|
||||
}
|
||||
|
||||
// 按模型分组 - 使用map优化
|
||||
modelMap := make(map[string]map[uint]models.Price, len(prices)/2) // 预分配合理大小
|
||||
for _, price := range prices {
|
||||
if _, exists := modelMap[price.Model]; !exists {
|
||||
modelMap[price.Model] = make(map[uint]models.Price, 5) // 假设每个模型有5个提供商
|
||||
}
|
||||
modelMap[price.Model][price.ChannelType] = price
|
||||
}
|
||||
|
||||
// 预分配rates切片,减少内存分配
|
||||
rates := make([]PriceRate, 0, len(prices))
|
||||
|
||||
// 计算倍率
|
||||
for model, providers := range modelMap {
|
||||
// 找出基准价格(通常是OpenAI的价格)
|
||||
var basePrice models.Price
|
||||
var found bool
|
||||
if baseProvider, exists := providers[1]; exists { // 直接检查ID为1的提供商
|
||||
basePrice = baseProvider
|
||||
found = true
|
||||
}
|
||||
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算其他厂商相对于基准价格的倍率
|
||||
for channelType, price := range providers {
|
||||
if channelType == 1 {
|
||||
continue // 跳过基准价格
|
||||
}
|
||||
|
||||
// 计算输入和输出的倍率
|
||||
inputRate := 0.0
|
||||
if basePrice.InputPrice > 0 {
|
||||
inputRate = price.InputPrice / basePrice.InputPrice
|
||||
}
|
||||
|
||||
outputRate := 0.0
|
||||
if basePrice.OutputPrice > 0 {
|
||||
outputRate = price.OutputPrice / basePrice.OutputPrice
|
||||
}
|
||||
|
||||
rates = append(rates, PriceRate{
|
||||
Model: model,
|
||||
ModelType: price.ModelType,
|
||||
Type: price.BillingType,
|
||||
ChannelType: channelType,
|
||||
Input: inputRate,
|
||||
Output: outputRate,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 存入缓存,有效期10分钟
|
||||
database.GlobalCache.Set(cacheKey, rates, 10*time.Minute)
|
||||
|
||||
c.JSON(http.StatusOK, rates)
|
||||
}
|
||||
|
||||
func ApproveAllPrices(c *gin.Context) {
|
||||
// 查找所有待审核的价格
|
||||
var pendingPrices []models.Price
|
||||
@ -527,4 +437,7 @@ func ApproveAllPrices(c *gin.Context) {
|
||||
func clearPriceCache() {
|
||||
// 由于我们无法精确知道哪些缓存键与价格相关,所以清除所有缓存
|
||||
database.GlobalCache.Clear()
|
||||
|
||||
// 同时清除价格倍率缓存
|
||||
rates.ClearRatesCache()
|
||||
}
|
||||
|
81
backend/handlers/rates/price_rates.go
Normal file
81
backend/handlers/rates/price_rates.go
Normal file
@ -0,0 +1,81 @@
|
||||
package rates
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"aimodels-prices/database"
|
||||
"aimodels-prices/models"
|
||||
)
|
||||
|
||||
// PriceRate 价格倍率结构
|
||||
type PriceRate struct {
|
||||
Model string `json:"model"`
|
||||
ModelType string `json:"model_type"`
|
||||
Type string `json:"type"`
|
||||
ChannelType uint `json:"channel_type"`
|
||||
Input float64 `json:"input"`
|
||||
Output float64 `json:"output"`
|
||||
}
|
||||
|
||||
// GetPriceRates 获取价格倍率
|
||||
func GetPriceRates(c *gin.Context) {
|
||||
cacheKey := "price_rates"
|
||||
|
||||
// 尝试从缓存获取
|
||||
if cachedData, found := database.GlobalCache.Get(cacheKey); found {
|
||||
if rates, ok := cachedData.([]PriceRate); ok {
|
||||
c.JSON(http.StatusOK, rates)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用索引优化查询,只查询需要的字段
|
||||
var prices []models.Price
|
||||
if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price, currency, status").
|
||||
Where("status = 'approved'").
|
||||
Find(&prices).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
|
||||
return
|
||||
}
|
||||
|
||||
// 预分配rates切片,减少内存分配
|
||||
rates := make([]PriceRate, 0, len(prices))
|
||||
|
||||
// 计算倍率
|
||||
for _, price := range prices {
|
||||
// 根据货币类型计算倍率
|
||||
var inputRate, outputRate float64
|
||||
|
||||
if price.Currency == "USD" {
|
||||
// 如果是美元,除以2
|
||||
inputRate = price.InputPrice / 2
|
||||
outputRate = price.OutputPrice / 2
|
||||
} else {
|
||||
// 如果是人民币或其他货币,除以14
|
||||
inputRate = price.InputPrice / 14
|
||||
outputRate = price.OutputPrice / 14
|
||||
}
|
||||
|
||||
rates = append(rates, PriceRate{
|
||||
Model: price.Model,
|
||||
ModelType: price.ModelType,
|
||||
Type: price.BillingType,
|
||||
ChannelType: price.ChannelType,
|
||||
Input: inputRate,
|
||||
Output: outputRate,
|
||||
})
|
||||
}
|
||||
|
||||
// 存入缓存,有效期24小时
|
||||
database.GlobalCache.Set(cacheKey, rates, 24*time.Hour)
|
||||
|
||||
c.JSON(http.StatusOK, rates)
|
||||
}
|
||||
|
||||
// ClearRatesCache 清除价格倍率缓存
|
||||
func ClearRatesCache() {
|
||||
database.GlobalCache.Delete("price_rates")
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
"aimodels-prices/config"
|
||||
"aimodels-prices/database"
|
||||
"aimodels-prices/handlers"
|
||||
"aimodels-prices/handlers/rates"
|
||||
"aimodels-prices/middleware"
|
||||
)
|
||||
|
||||
@ -56,7 +57,7 @@ func main() {
|
||||
prices := api.Group("/prices")
|
||||
{
|
||||
prices.GET("", handlers.GetPrices)
|
||||
prices.GET("/rates", handlers.GetPriceRates)
|
||||
prices.GET("/rates", rates.GetPriceRates)
|
||||
prices.POST("", middleware.AuthRequired(), handlers.CreatePrice)
|
||||
prices.PUT("/:id", middleware.AuthRequired(), handlers.UpdatePrice)
|
||||
prices.DELETE("/:id", middleware.AuthRequired(), handlers.DeletePrice)
|
||||
|
Loading…
x
Reference in New Issue
Block a user