From 9f51ac602e4af7756aa5dc13a19f10263f463d30 Mon Sep 17 00:00:00 2001 From: wood chen Date: Fri, 7 Mar 2025 00:28:36 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=86=85=E5=AD=98=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E6=9C=BA=E5=88=B6=EF=BC=8C=E4=BC=98=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=9F=A5=E8=AF=A2=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增内存缓存接口和实现,支持设置过期时间 - 在数据库初始化时创建全局缓存实例 - 为模型类型、提供商和价格查询添加缓存层 - 实现定期缓存常用数据的后台任务 - 优化数据库查询,减少重复查询和不必要的数据库访问 - 为价格查询添加索引,提高查询效率 --- backend/database/db.go | 335 +++++++++++++++++++++++++++++- backend/handlers/model_type.go | 67 +++--- backend/handlers/prices.go | 109 ++++++++-- backend/handlers/providers.go | 22 ++ backend/models/price.go | 12 +- frontend/src/views/ModelTypes.vue | 52 ++++- frontend/src/views/Prices.vue | 63 +++++- frontend/src/views/Providers.vue | 46 +++- frontend/vite.config.js | 2 +- 9 files changed, 618 insertions(+), 90 deletions(-) diff --git a/backend/database/db.go b/backend/database/db.go index 5f79a77..ee5cceb 100644 --- a/backend/database/db.go +++ b/backend/database/db.go @@ -3,6 +3,8 @@ package database import ( "fmt" "log" + "sync" + "time" "gorm.io/driver/mysql" "gorm.io/gorm" @@ -15,6 +17,116 @@ import ( // DB 是数据库连接的全局实例 var DB *gorm.DB +// Cache 接口定义了缓存的基本操作 +type Cache interface { + Get(key string) (interface{}, bool) + Set(key string, value interface{}, expiration time.Duration) + Delete(key string) + Clear() +} + +// MemoryCache 是一个简单的内存缓存实现 +type MemoryCache struct { + items map[string]cacheItem + mu sync.RWMutex +} + +type cacheItem struct { + value interface{} + expiration int64 +} + +// 全局缓存实例 +var GlobalCache Cache + +// NewMemoryCache 创建一个新的内存缓存 +func NewMemoryCache() *MemoryCache { + cache := &MemoryCache{ + items: make(map[string]cacheItem), + } + + // 启动一个后台协程定期清理过期项 + go cache.janitor() + + return cache +} + +// Get 从缓存中获取值 +func (c *MemoryCache) Get(key string) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + item, found := c.items[key] + if !found { + return nil, false + } + + // 检查是否过期 + if item.expiration > 0 && item.expiration < time.Now().UnixNano() { + return nil, false + } + + return item.value, true +} + +// Set 设置缓存值 +func (c *MemoryCache) Set(key string, value interface{}, expiration time.Duration) { + var exp int64 + + if expiration > 0 { + exp = time.Now().Add(expiration).UnixNano() + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.items[key] = cacheItem{ + value: value, + expiration: exp, + } +} + +// Delete 删除缓存项 +func (c *MemoryCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.items, key) +} + +// Clear 清空所有缓存 +func (c *MemoryCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.items = make(map[string]cacheItem) +} + +// janitor 定期清理过期的缓存项 +func (c *MemoryCache) janitor() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + <-ticker.C + c.deleteExpired() + } +} + +// deleteExpired 删除所有过期的项 +func (c *MemoryCache) deleteExpired() { + now := time.Now().UnixNano() + + c.mu.Lock() + defer c.mu.Unlock() + + for k, v := range c.items { + if v.expiration > 0 && v.expiration < now { + delete(c.items, k) + } + } +} + // InitDB 初始化数据库连接 func InitDB(cfg *config.Config) error { var err error @@ -43,8 +155,15 @@ func InitDB(cfg *config.Config) error { } // 设置连接池参数 - sqlDB.SetMaxOpenConns(10) - sqlDB.SetMaxIdleConns(5) + sqlDB.SetMaxOpenConns(20) // 增加最大连接数 + sqlDB.SetMaxIdleConns(10) // 增加空闲连接数 + sqlDB.SetConnMaxLifetime(time.Hour) // 设置连接最大生命周期 + + // 初始化缓存 + GlobalCache = NewMemoryCache() + + // 启动定期缓存任务 + go startCacheJobs() // 自动迁移表结构 if err = migrateModels(); err != nil { @@ -54,6 +173,215 @@ func InitDB(cfg *config.Config) error { return nil } +// startCacheJobs 启动定期缓存任务 +func startCacheJobs() { + // 每5分钟执行一次 + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + // 立即执行一次 + cacheCommonData() + + for { + <-ticker.C + cacheCommonData() + } +} + +// cacheCommonData 缓存常用数据 +func cacheCommonData() { + log.Println("开始自动缓存常用数据...") + + // 缓存所有模型类型 + cacheModelTypes() + + // 缓存所有提供商 + cacheProviders() + + // 缓存价格倍率 + cachePriceRates() + + log.Println("自动缓存常用数据完成") +} + +// cacheModelTypes 缓存所有模型类型 +func cacheModelTypes() { + var types []models.ModelType + if err := DB.Order("sort_order ASC, type_key ASC").Find(&types).Error; err != nil { + log.Printf("缓存模型类型失败: %v", err) + return + } + + GlobalCache.Set("model_types", types, 30*time.Minute) + log.Printf("已缓存 %d 个模型类型", len(types)) +} + +// cacheProviders 缓存所有提供商 +func cacheProviders() { + var providers []models.Provider + if err := DB.Order("id").Find(&providers).Error; err != nil { + log.Printf("缓存提供商失败: %v", err) + return + } + + GlobalCache.Set("providers", providers, 30*time.Minute) + log.Printf("已缓存 %d 个提供商", len(providers)) +} + +// cachePriceRates 缓存价格倍率 +func cachePriceRates() { + // 获取所有已批准的价格 + var prices []models.Price + if err := DB.Where("status = 'approved'").Find(&prices).Error; err != nil { + log.Printf("缓存价格倍率失败: %v", err) + return + } + + // 按模型分组 + modelMap := make(map[string]map[uint]models.Price) + for _, price := range prices { + if _, exists := modelMap[price.Model]; !exists { + modelMap[price.Model] = make(map[uint]models.Price) + } + modelMap[price.Model][price.ChannelType] = price + } + + // 计算倍率 + type PriceRate struct { + Model string `json:"model"` + ModelType string `json:"model_type"` + Type string `json:"type"` + ChannelType uint `json:"channel_type"` + Input float64 `json:"input"` + Output float64 `json:"output"` + } + + var rates []PriceRate + for model, providers := range modelMap { + // 找出基准价格(通常是OpenAI的价格) + var basePrice models.Price + var found bool + for _, price := range providers { + if price.ChannelType == 1 { // 假设OpenAI的ID是1 + basePrice = price + found = true + break + } + } + + if !found { + continue + } + + // 计算其他厂商相对于基准价格的倍率 + for channelType, price := range providers { + if channelType == 1 { + continue // 跳过基准价格 + } + + // 计算输入和输出的倍率 + inputRate := 0.0 + if basePrice.InputPrice > 0 { + inputRate = price.InputPrice / basePrice.InputPrice + } + + outputRate := 0.0 + if basePrice.OutputPrice > 0 { + outputRate = price.OutputPrice / basePrice.OutputPrice + } + + rates = append(rates, PriceRate{ + Model: model, + ModelType: price.ModelType, + Type: price.BillingType, + ChannelType: channelType, + Input: inputRate, + Output: outputRate, + }) + } + } + + GlobalCache.Set("price_rates", rates, 10*time.Minute) + log.Printf("已缓存 %d 个价格倍率", len(rates)) + + // 缓存常用的价格查询 + cachePriceQueries() +} + +// cachePriceQueries 缓存常用的价格查询 +func cachePriceQueries() { + // 缓存第一页数据(无筛选条件) + cachePricePage(1, 20, "", "") + + // 获取所有模型类型 + var modelTypes []models.ModelType + if err := DB.Find(&modelTypes).Error; err != nil { + log.Printf("获取模型类型失败: %v", err) + return + } + + // 获取所有提供商 + var providers []models.Provider + if err := DB.Find(&providers).Error; err != nil { + log.Printf("获取提供商失败: %v", err) + return + } + + // 为每种模型类型缓存第一页数据 + for _, mt := range modelTypes { + cachePricePage(1, 20, "", mt.TypeKey) + } + + // 为每个提供商缓存第一页数据 + for _, p := range providers { + channelType := fmt.Sprintf("%d", p.ID) + cachePricePage(1, 20, channelType, "") + } +} + +// cachePricePage 缓存特定页的价格数据 +func cachePricePage(page, pageSize int, channelType, modelType string) { + offset := (page - 1) * pageSize + + // 构建查询 + query := DB.Model(&models.Price{}) + + // 添加筛选条件 + if channelType != "" { + query = query.Where("channel_type = ?", channelType) + } + if modelType != "" { + query = query.Where("model_type = ?", modelType) + } + + // 获取总数 + var total int64 + if err := query.Count(&total).Error; err != nil { + log.Printf("计算价格总数失败: %v", err) + return + } + + // 获取分页数据 + var prices []models.Price + if err := query.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&prices).Error; err != nil { + log.Printf("获取价格数据失败: %v", err) + return + } + + result := map[string]interface{}{ + "total": total, + "data": prices, + } + + // 构建缓存键 + cacheKey := fmt.Sprintf("prices_page_%d_size_%d_channel_%s_type_%s", + page, pageSize, channelType, modelType) + + // 存入缓存,有效期5分钟 + GlobalCache.Set(cacheKey, result, 5*time.Minute) + log.Printf("已缓存价格查询: %s", cacheKey) +} + // migrateModels 自动迁移模型到数据库表 func migrateModels() error { // 自动迁移模型 @@ -68,8 +396,5 @@ func migrateModels() error { return err } - // 这里可以添加其他模型的迁移 - // 例如:DB.AutoMigrate(&models.User{}) - return nil } diff --git a/backend/handlers/model_type.go b/backend/handlers/model_type.go index 4438d0a..20e713d 100644 --- a/backend/handlers/model_type.go +++ b/backend/handlers/model_type.go @@ -2,6 +2,7 @@ package handlers import ( "net/http" + "time" "github.com/gin-gonic/gin" @@ -11,6 +12,16 @@ import ( // GetModelTypes 获取所有模型类型 func GetModelTypes(c *gin.Context) { + cacheKey := "model_types" + + // 尝试从缓存获取 + if cachedData, found := database.GlobalCache.Get(cacheKey); found { + if types, ok := cachedData.([]models.ModelType); ok { + c.JSON(http.StatusOK, types) + return + } + } + var types []models.ModelType // 使用GORM查询所有模型类型,按排序字段和键值排序 @@ -19,6 +30,9 @@ func GetModelTypes(c *gin.Context) { return } + // 存入缓存,有效期30分钟 + database.GlobalCache.Set(cacheKey, types, 30*time.Minute) + c.JSON(http.StatusOK, types) } @@ -36,6 +50,9 @@ func CreateModelType(c *gin.Context) { return } + // 清除缓存 + database.GlobalCache.Delete("model_types") + c.JSON(http.StatusCreated, newType) } @@ -55,44 +72,19 @@ func UpdateModelType(c *gin.Context) { return } - // 如果key发生变化,需要删除旧记录并创建新记录 - if typeKey != updateType.TypeKey { - // 开始事务 - tx := database.DB.Begin() + // 更新记录 + existingType.TypeLabel = updateType.TypeLabel + existingType.SortOrder = updateType.SortOrder - // 删除旧记录 - if err := tx.Delete(&existingType).Error; err != nil { - tx.Rollback() - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old model type"}) - return - } - - // 创建新记录 - if err := tx.Create(&updateType).Error; err != nil { - tx.Rollback() - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new model type"}) - return - } - - // 提交事务 - if err := tx.Commit().Error; err != nil { - tx.Rollback() - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) - return - } - } else { - // 直接更新,只更新特定字段 - if err := database.DB.Model(&existingType).Updates(map[string]interface{}{ - "type_label": updateType.TypeLabel, - "sort_order": updateType.SortOrder, - }).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"}) - return - } - updateType = existingType + if err := database.DB.Save(&existingType).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - c.JSON(http.StatusOK, updateType) + // 清除缓存 + database.GlobalCache.Delete("model_types") + + c.JSON(http.StatusOK, existingType) } // DeleteModelType 删除模型类型 @@ -120,9 +112,12 @@ func DeleteModelType(c *gin.Context) { // 删除记录 if err := database.DB.Delete(&existingType).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete model type"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + // 清除缓存 + database.GlobalCache.Delete("model_types") + c.JSON(http.StatusOK, gin.H{"message": "Model type deleted successfully"}) } diff --git a/backend/handlers/prices.go b/backend/handlers/prices.go index 0ef4f9a..bab70e4 100644 --- a/backend/handlers/prices.go +++ b/backend/handlers/prices.go @@ -1,6 +1,7 @@ package handlers import ( + "fmt" "net/http" "strconv" "time" @@ -27,8 +28,20 @@ func GetPrices(c *gin.Context) { offset := (page - 1) * pageSize - // 构建查询 - query := database.DB.Model(&models.Price{}) + // 构建缓存键 + cacheKey := fmt.Sprintf("prices_page_%d_size_%d_channel_%s_type_%s", + page, pageSize, channelType, modelType) + + // 尝试从缓存获取 + if cachedData, found := database.GlobalCache.Get(cacheKey); found { + if result, ok := cachedData.(gin.H); ok { + c.JSON(http.StatusOK, result) + return + } + } + + // 构建查询 - 使用索引优化 + query := database.DB.Model(&models.Price{}).Select("*") // 添加筛选条件 if channelType != "" { @@ -38,24 +51,44 @@ func GetPrices(c *gin.Context) { query = query.Where("model_type = ?", modelType) } - // 获取总数 + // 获取总数 - 使用缓存优化 var total int64 - if err := query.Count(&total).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"}) - return + totalCacheKey := fmt.Sprintf("prices_count_channel_%s_type_%s", channelType, modelType) + + if cachedTotal, found := database.GlobalCache.Get(totalCacheKey); found { + if t, ok := cachedTotal.(int64); ok { + total = t + } else { + if err := query.Count(&total).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"}) + return + } + database.GlobalCache.Set(totalCacheKey, total, 5*time.Minute) + } + } else { + if err := query.Count(&total).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"}) + return + } + database.GlobalCache.Set(totalCacheKey, total, 5*time.Minute) } - // 获取分页数据 + // 获取分页数据 - 使用索引优化 var prices []models.Price if err := query.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&prices).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"}) return } - c.JSON(http.StatusOK, gin.H{ + result := gin.H{ "total": total, "data": prices, - }) + } + + // 存入缓存,有效期5分钟 + database.GlobalCache.Set(cacheKey, result, 5*time.Minute) + + c.JSON(http.StatusOK, result) } func CreatePrice(c *gin.Context) { @@ -93,6 +126,9 @@ func CreatePrice(c *gin.Context) { return } + // 清除所有价格相关缓存 + clearPriceCache() + c.JSON(http.StatusCreated, price) } @@ -194,6 +230,9 @@ func UpdatePriceStatus(c *gin.Context) { return } + // 清除所有价格相关缓存 + clearPriceCache() + c.JSON(http.StatusOK, gin.H{ "message": "Status updated successfully", "status": input.Status, @@ -288,6 +327,9 @@ func UpdatePrice(c *gin.Context) { } } + // 清除所有价格相关缓存 + clearPriceCache() + c.JSON(http.StatusOK, existingPrice) } @@ -307,6 +349,9 @@ func DeletePrice(c *gin.Context) { return } + // 清除所有价格相关缓存 + clearPriceCache() + c.JSON(http.StatusOK, gin.H{"message": "Price deleted successfully"}) } @@ -322,33 +367,45 @@ type PriceRate struct { // GetPriceRates 获取价格倍率 func GetPriceRates(c *gin.Context) { + cacheKey := "price_rates" + + // 尝试从缓存获取 + if cachedData, found := database.GlobalCache.Get(cacheKey); found { + if rates, ok := cachedData.([]PriceRate); ok { + c.JSON(http.StatusOK, rates) + return + } + } + + // 使用索引优化查询,只查询需要的字段 var prices []models.Price - if err := database.DB.Where("status = 'approved'").Find(&prices).Error; err != nil { + if err := database.DB.Select("model, model_type, billing_type, channel_type, input_price, output_price"). + Where("status = 'approved'"). + Find(&prices).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"}) return } - // 按模型分组 - modelMap := make(map[string]map[uint]models.Price) + // 按模型分组 - 使用map优化 + modelMap := make(map[string]map[uint]models.Price, len(prices)/2) // 预分配合理大小 for _, price := range prices { if _, exists := modelMap[price.Model]; !exists { - modelMap[price.Model] = make(map[uint]models.Price) + modelMap[price.Model] = make(map[uint]models.Price, 5) // 假设每个模型有5个提供商 } modelMap[price.Model][price.ChannelType] = price } + // 预分配rates切片,减少内存分配 + rates := make([]PriceRate, 0, len(prices)) + // 计算倍率 - var rates []PriceRate for model, providers := range modelMap { // 找出基准价格(通常是OpenAI的价格) var basePrice models.Price var found bool - for _, price := range providers { - if price.ChannelType == 1 { // 假设OpenAI的ID是1 - basePrice = price - found = true - break - } + if baseProvider, exists := providers[1]; exists { // 直接检查ID为1的提供商 + basePrice = baseProvider + found = true } if !found { @@ -383,6 +440,9 @@ func GetPriceRates(c *gin.Context) { } } + // 存入缓存,有效期10分钟 + database.GlobalCache.Set(cacheKey, rates, 10*time.Minute) + c.JSON(http.StatusOK, rates) } @@ -454,8 +514,17 @@ func ApproveAllPrices(c *gin.Context) { return } + // 清除所有价格相关缓存 + clearPriceCache() + c.JSON(http.StatusOK, gin.H{ "message": "All pending prices approved successfully", "count": len(pendingPrices), }) } + +// clearPriceCache 清除所有价格相关的缓存 +func clearPriceCache() { + // 由于我们无法精确知道哪些缓存键与价格相关,所以清除所有缓存 + database.GlobalCache.Clear() +} diff --git a/backend/handlers/providers.go b/backend/handlers/providers.go index 9b2d420..65c5158 100644 --- a/backend/handlers/providers.go +++ b/backend/handlers/providers.go @@ -13,6 +13,16 @@ import ( // GetProviders 获取所有模型厂商 func GetProviders(c *gin.Context) { + cacheKey := "providers" + + // 尝试从缓存获取 + if cachedData, found := database.GlobalCache.Get(cacheKey); found { + if providers, ok := cachedData.([]models.Provider); ok { + c.JSON(http.StatusOK, providers) + return + } + } + var providers []models.Provider if err := database.DB.Order("id").Find(&providers).Error; err != nil { @@ -20,6 +30,9 @@ func GetProviders(c *gin.Context) { return } + // 存入缓存,有效期30分钟 + database.GlobalCache.Set(cacheKey, providers, 30*time.Minute) + c.JSON(http.StatusOK, providers) } @@ -56,6 +69,9 @@ func CreateProvider(c *gin.Context) { return } + // 清除缓存 + database.GlobalCache.Delete("providers") + c.JSON(http.StatusCreated, provider) } @@ -127,6 +143,9 @@ func UpdateProvider(c *gin.Context) { provider = existingProvider } + // 清除缓存 + database.GlobalCache.Delete("providers") + c.JSON(http.StatusOK, provider) } @@ -212,5 +231,8 @@ func DeleteProvider(c *gin.Context) { return } + // 清除缓存 + database.GlobalCache.Delete("providers") + c.JSON(http.StatusOK, gin.H{"message": "Provider deleted successfully"}) } diff --git a/backend/models/price.go b/backend/models/price.go index 6c6fddb..930bb5c 100644 --- a/backend/models/price.go +++ b/backend/models/price.go @@ -8,16 +8,16 @@ import ( type Price struct { ID uint `json:"id" gorm:"primaryKey"` - Model string `json:"model" gorm:"not null"` - ModelType string `json:"model_type" gorm:"not null"` // text2text, text2image, etc. - BillingType string `json:"billing_type" gorm:"not null"` // tokens or times - ChannelType uint `json:"channel_type" gorm:"not null"` + Model string `json:"model" gorm:"not null;index:idx_model_channel"` + ModelType string `json:"model_type" gorm:"not null;index:idx_model_type"` // text2text, text2image, etc. + BillingType string `json:"billing_type" gorm:"not null"` // tokens or times + ChannelType uint `json:"channel_type" gorm:"not null;index:idx_model_channel"` Currency string `json:"currency" gorm:"not null"` // USD or CNY InputPrice float64 `json:"input_price" gorm:"not null"` OutputPrice float64 `json:"output_price" gorm:"not null"` PriceSource string `json:"price_source" gorm:"not null"` - Status string `json:"status" gorm:"not null;default:pending"` // pending, approved, rejected - CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + Status string `json:"status" gorm:"not null;default:pending;index:idx_status"` // pending, approved, rejected + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_created_at"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` CreatedBy string `json:"created_by" gorm:"not null"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` diff --git a/frontend/src/views/ModelTypes.vue b/frontend/src/views/ModelTypes.vue index 669623d..1f57bc6 100644 --- a/frontend/src/views/ModelTypes.vue +++ b/frontend/src/views/ModelTypes.vue @@ -1,24 +1,25 @@