mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 13:41:59 +08:00
401 lines
8.6 KiB
Go
401 lines
8.6 KiB
Go
package database
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"sync"
|
||
"time"
|
||
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/logger"
|
||
|
||
"aimodels-prices/config"
|
||
"aimodels-prices/models"
|
||
)
|
||
|
||
// DB 是数据库连接的全局实例
|
||
var DB *gorm.DB
|
||
|
||
// Cache 接口定义了缓存的基本操作
|
||
type Cache interface {
|
||
Get(key string) (interface{}, bool)
|
||
Set(key string, value interface{}, expiration time.Duration)
|
||
Delete(key string)
|
||
Clear()
|
||
}
|
||
|
||
// MemoryCache 是一个简单的内存缓存实现
|
||
type MemoryCache struct {
|
||
items map[string]cacheItem
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
type cacheItem struct {
|
||
value interface{}
|
||
expiration int64
|
||
}
|
||
|
||
// 全局缓存实例
|
||
var GlobalCache Cache
|
||
|
||
// NewMemoryCache 创建一个新的内存缓存
|
||
func NewMemoryCache() *MemoryCache {
|
||
cache := &MemoryCache{
|
||
items: make(map[string]cacheItem),
|
||
}
|
||
|
||
// 启动一个后台协程定期清理过期项
|
||
go cache.janitor()
|
||
|
||
return cache
|
||
}
|
||
|
||
// Get 从缓存中获取值
|
||
func (c *MemoryCache) Get(key string) (interface{}, bool) {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
item, found := c.items[key]
|
||
if !found {
|
||
return nil, false
|
||
}
|
||
|
||
// 检查是否过期
|
||
if item.expiration > 0 && item.expiration < time.Now().UnixNano() {
|
||
return nil, false
|
||
}
|
||
|
||
return item.value, true
|
||
}
|
||
|
||
// Set 设置缓存值
|
||
func (c *MemoryCache) Set(key string, value interface{}, expiration time.Duration) {
|
||
var exp int64
|
||
|
||
if expiration > 0 {
|
||
exp = time.Now().Add(expiration).UnixNano()
|
||
}
|
||
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
c.items[key] = cacheItem{
|
||
value: value,
|
||
expiration: exp,
|
||
}
|
||
}
|
||
|
||
// Delete 删除缓存项
|
||
func (c *MemoryCache) Delete(key string) {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
delete(c.items, key)
|
||
}
|
||
|
||
// Clear 清空所有缓存
|
||
func (c *MemoryCache) Clear() {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
c.items = make(map[string]cacheItem)
|
||
}
|
||
|
||
// janitor 定期清理过期的缓存项
|
||
func (c *MemoryCache) janitor() {
|
||
ticker := time.NewTicker(time.Minute)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
<-ticker.C
|
||
c.deleteExpired()
|
||
}
|
||
}
|
||
|
||
// deleteExpired 删除所有过期的项
|
||
func (c *MemoryCache) deleteExpired() {
|
||
now := time.Now().UnixNano()
|
||
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
for k, v := range c.items {
|
||
if v.expiration > 0 && v.expiration < now {
|
||
delete(c.items, k)
|
||
}
|
||
}
|
||
}
|
||
|
||
// InitDB 初始化数据库连接
|
||
func InitDB(cfg *config.Config) error {
|
||
var err error
|
||
|
||
// 构建MySQL DSN
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||
cfg.DBUser,
|
||
cfg.DBPassword,
|
||
cfg.DBHost,
|
||
cfg.DBPort,
|
||
cfg.DBName,
|
||
)
|
||
|
||
// 连接MySQL
|
||
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
Logger: logger.Default.LogMode(logger.Error),
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("failed to connect to MySQL: %v", err)
|
||
}
|
||
|
||
// 获取底层的SQL DB
|
||
sqlDB, err := DB.DB()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get underlying SQL DB: %v", err)
|
||
}
|
||
|
||
// 设置连接池参数
|
||
sqlDB.SetMaxOpenConns(20) // 增加最大连接数
|
||
sqlDB.SetMaxIdleConns(10) // 增加空闲连接数
|
||
sqlDB.SetConnMaxLifetime(time.Hour) // 设置连接最大生命周期
|
||
|
||
// 初始化缓存
|
||
GlobalCache = NewMemoryCache()
|
||
|
||
// 启动定期缓存任务
|
||
go startCacheJobs()
|
||
|
||
// 自动迁移表结构
|
||
if err = migrateModels(); err != nil {
|
||
return fmt.Errorf("failed to migrate models: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// startCacheJobs 启动定期缓存任务
|
||
func startCacheJobs() {
|
||
// 每5分钟执行一次
|
||
ticker := time.NewTicker(5 * time.Minute)
|
||
defer ticker.Stop()
|
||
|
||
// 立即执行一次
|
||
cacheCommonData()
|
||
|
||
for {
|
||
<-ticker.C
|
||
cacheCommonData()
|
||
}
|
||
}
|
||
|
||
// cacheCommonData 缓存常用数据
|
||
func cacheCommonData() {
|
||
log.Println("开始自动缓存常用数据...")
|
||
|
||
// 缓存所有模型类型
|
||
cacheModelTypes()
|
||
|
||
// 缓存所有提供商
|
||
cacheProviders()
|
||
|
||
// 缓存价格倍率
|
||
cachePriceRates()
|
||
|
||
log.Println("自动缓存常用数据完成")
|
||
}
|
||
|
||
// cacheModelTypes 缓存所有模型类型
|
||
func cacheModelTypes() {
|
||
var types []models.ModelType
|
||
if err := DB.Order("sort_order ASC, type_key ASC").Find(&types).Error; err != nil {
|
||
log.Printf("缓存模型类型失败: %v", err)
|
||
return
|
||
}
|
||
|
||
GlobalCache.Set("model_types", types, 30*time.Minute)
|
||
log.Printf("已缓存 %d 个模型类型", len(types))
|
||
}
|
||
|
||
// cacheProviders 缓存所有提供商
|
||
func cacheProviders() {
|
||
var providers []models.Provider
|
||
if err := DB.Order("id").Find(&providers).Error; err != nil {
|
||
log.Printf("缓存提供商失败: %v", err)
|
||
return
|
||
}
|
||
|
||
GlobalCache.Set("providers", providers, 30*time.Minute)
|
||
log.Printf("已缓存 %d 个提供商", len(providers))
|
||
}
|
||
|
||
// cachePriceRates 缓存价格倍率
|
||
func cachePriceRates() {
|
||
// 获取所有已批准的价格
|
||
var prices []models.Price
|
||
if err := DB.Where("status = 'approved'").Find(&prices).Error; err != nil {
|
||
log.Printf("缓存价格倍率失败: %v", err)
|
||
return
|
||
}
|
||
|
||
// 按模型分组
|
||
modelMap := make(map[string]map[uint]models.Price)
|
||
for _, price := range prices {
|
||
if _, exists := modelMap[price.Model]; !exists {
|
||
modelMap[price.Model] = make(map[uint]models.Price)
|
||
}
|
||
modelMap[price.Model][price.ChannelType] = price
|
||
}
|
||
|
||
// 计算倍率
|
||
type PriceRate struct {
|
||
Model string `json:"model"`
|
||
ModelType string `json:"model_type"`
|
||
Type string `json:"type"`
|
||
ChannelType uint `json:"channel_type"`
|
||
Input float64 `json:"input"`
|
||
Output float64 `json:"output"`
|
||
}
|
||
|
||
var rates []PriceRate
|
||
for model, providers := range modelMap {
|
||
// 找出基准价格(通常是OpenAI的价格)
|
||
var basePrice models.Price
|
||
var found bool
|
||
for _, price := range providers {
|
||
if price.ChannelType == 1 { // 假设OpenAI的ID是1
|
||
basePrice = price
|
||
found = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !found {
|
||
continue
|
||
}
|
||
|
||
// 计算其他厂商相对于基准价格的倍率
|
||
for channelType, price := range providers {
|
||
if channelType == 1 {
|
||
continue // 跳过基准价格
|
||
}
|
||
|
||
// 计算输入和输出的倍率
|
||
inputRate := 0.0
|
||
if basePrice.InputPrice > 0 {
|
||
inputRate = price.InputPrice / basePrice.InputPrice
|
||
}
|
||
|
||
outputRate := 0.0
|
||
if basePrice.OutputPrice > 0 {
|
||
outputRate = price.OutputPrice / basePrice.OutputPrice
|
||
}
|
||
|
||
rates = append(rates, PriceRate{
|
||
Model: model,
|
||
ModelType: price.ModelType,
|
||
Type: price.BillingType,
|
||
ChannelType: channelType,
|
||
Input: inputRate,
|
||
Output: outputRate,
|
||
})
|
||
}
|
||
}
|
||
|
||
GlobalCache.Set("price_rates", rates, 10*time.Minute)
|
||
log.Printf("已缓存 %d 个价格倍率", len(rates))
|
||
|
||
// 缓存常用的价格查询
|
||
cachePriceQueries()
|
||
}
|
||
|
||
// cachePriceQueries 缓存常用的价格查询
|
||
func cachePriceQueries() {
|
||
// 缓存第一页数据(无筛选条件)
|
||
cachePricePage(1, 20, "", "")
|
||
|
||
// 获取所有模型类型
|
||
var modelTypes []models.ModelType
|
||
if err := DB.Find(&modelTypes).Error; err != nil {
|
||
log.Printf("获取模型类型失败: %v", err)
|
||
return
|
||
}
|
||
|
||
// 获取所有提供商
|
||
var providers []models.Provider
|
||
if err := DB.Find(&providers).Error; err != nil {
|
||
log.Printf("获取提供商失败: %v", err)
|
||
return
|
||
}
|
||
|
||
// 为每种模型类型缓存第一页数据
|
||
for _, mt := range modelTypes {
|
||
cachePricePage(1, 20, "", mt.TypeKey)
|
||
}
|
||
|
||
// 为每个提供商缓存第一页数据
|
||
for _, p := range providers {
|
||
channelType := fmt.Sprintf("%d", p.ID)
|
||
cachePricePage(1, 20, channelType, "")
|
||
}
|
||
}
|
||
|
||
// cachePricePage 缓存特定页的价格数据
|
||
func cachePricePage(page, pageSize int, channelType, modelType string) {
|
||
offset := (page - 1) * pageSize
|
||
|
||
// 构建查询
|
||
query := DB.Model(&models.Price{})
|
||
|
||
// 添加筛选条件
|
||
if channelType != "" {
|
||
query = query.Where("channel_type = ?", channelType)
|
||
}
|
||
if modelType != "" {
|
||
query = query.Where("model_type = ?", modelType)
|
||
}
|
||
|
||
// 获取总数
|
||
var total int64
|
||
if err := query.Count(&total).Error; err != nil {
|
||
log.Printf("计算价格总数失败: %v", err)
|
||
return
|
||
}
|
||
|
||
// 获取分页数据
|
||
var prices []models.Price
|
||
if err := query.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&prices).Error; err != nil {
|
||
log.Printf("获取价格数据失败: %v", err)
|
||
return
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"total": total,
|
||
"data": prices,
|
||
}
|
||
|
||
// 构建缓存键
|
||
cacheKey := fmt.Sprintf("prices_page_%d_size_%d_channel_%s_type_%s",
|
||
page, pageSize, channelType, modelType)
|
||
|
||
// 存入缓存,有效期5分钟
|
||
GlobalCache.Set(cacheKey, result, 5*time.Minute)
|
||
log.Printf("已缓存价格查询: %s", cacheKey)
|
||
}
|
||
|
||
// migrateModels 自动迁移模型到数据库表
|
||
func migrateModels() error {
|
||
// 自动迁移模型
|
||
if err := DB.AutoMigrate(
|
||
&models.ModelType{},
|
||
&models.Price{},
|
||
&models.Provider{},
|
||
&models.User{},
|
||
&models.Session{},
|
||
); err != nil {
|
||
log.Printf("Failed to migrate tables: %v", err)
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|