mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 13:41:59 +08:00
重构价格管理模块,完全迁移到 GORM ORM
- 将价格相关处理函数从原生 SQL 完全迁移到 GORM - 优化数据库查询逻辑,使用 GORM 的链式查询和更新方法 - 重构价格审核、更新和批量审核功能,使用事务处理 - 简化数据库操作,移除手动 SQL 查询和扫描逻辑 - 优化价格倍率计算方法,提高代码可读性和性能
This commit is contained in:
parent
31f65a9301
commit
449f95d1b5
@ -81,10 +81,11 @@ func UpdateModelType(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 直接更新
|
// 直接更新,只更新特定字段
|
||||||
existingType.TypeLabel = updateType.TypeLabel
|
if err := database.DB.Model(&existingType).Updates(map[string]interface{}{
|
||||||
existingType.SortOrder = updateType.SortOrder
|
"type_label": updateType.TypeLabel,
|
||||||
if err := database.DB.Save(&existingType).Error; err != nil {
|
"sort_order": updateType.SortOrder,
|
||||||
|
}).Error; err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1,20 +1,17 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"aimodels-prices/database"
|
||||||
"aimodels-prices/models"
|
"aimodels-prices/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetPrices(c *gin.Context) {
|
func GetPrices(c *gin.Context) {
|
||||||
db := c.MustGet("db").(*sql.DB)
|
|
||||||
|
|
||||||
// 获取分页和筛选参数
|
// 获取分页和筛选参数
|
||||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||||
@ -30,75 +27,34 @@ func GetPrices(c *gin.Context) {
|
|||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
|
|
||||||
// 构建查询条件
|
// 构建查询
|
||||||
var conditions []string
|
query := database.DB.Model(&models.Price{})
|
||||||
var args []interface{}
|
|
||||||
|
|
||||||
|
// 添加筛选条件
|
||||||
if channelType != "" {
|
if channelType != "" {
|
||||||
conditions = append(conditions, "channel_type = ?")
|
query = query.Where("channel_type = ?", channelType)
|
||||||
args = append(args, channelType)
|
|
||||||
}
|
}
|
||||||
if modelType != "" {
|
if modelType != "" {
|
||||||
conditions = append(conditions, "model_type = ?")
|
query = query.Where("model_type = ?", modelType)
|
||||||
args = append(args, modelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 组合WHERE子句
|
|
||||||
var whereClause string
|
|
||||||
if len(conditions) > 0 {
|
|
||||||
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取总数
|
// 获取总数
|
||||||
var total int
|
var total int64
|
||||||
countQuery := "SELECT COUNT(*) FROM price"
|
if err := query.Count(&total).Error; err != nil {
|
||||||
if whereClause != "" {
|
|
||||||
countQuery += " " + whereClause
|
|
||||||
}
|
|
||||||
err := db.QueryRow(countQuery, args...).Scan(&total)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用分页查询
|
// 获取分页数据
|
||||||
query := `
|
var prices []models.Price
|
||||||
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
|
if err := query.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&prices).Error; err != nil {
|
||||||
price_source, status, created_at, updated_at, created_by,
|
|
||||||
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
|
|
||||||
temp_input_price, temp_output_price, temp_price_source, updated_by
|
|
||||||
FROM price`
|
|
||||||
if whereClause != "" {
|
|
||||||
query += " " + whereClause
|
|
||||||
}
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
|
||||||
args = append(args, pageSize, offset)
|
|
||||||
|
|
||||||
rows, err := db.Query(query, args...)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var prices []models.Price
|
|
||||||
for rows.Next() {
|
|
||||||
var price models.Price
|
|
||||||
if err := rows.Scan(
|
|
||||||
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
|
|
||||||
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
|
|
||||||
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
|
|
||||||
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
|
|
||||||
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
prices = append(prices, price)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"total": total,
|
"total": total,
|
||||||
"prices": prices,
|
"data": prices,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,46 +66,33 @@ func CreatePrice(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证模型厂商ID是否存在
|
// 验证模型厂商ID是否存在
|
||||||
db := c.MustGet("db").(*sql.DB)
|
var provider models.Provider
|
||||||
var providerExists bool
|
if err := database.DB.Where("id = ?", price.ChannelType).First(&provider).Error; err != nil {
|
||||||
err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM provider WHERE id = ?)", price.ChannelType).Scan(&providerExists)
|
|
||||||
if err != nil || !providerExists {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查同一厂商下是否已存在相同名称的模型
|
// 检查同一厂商下是否已存在相同名称的模型
|
||||||
var modelExists bool
|
var count int64
|
||||||
err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM price WHERE channel_type = ? AND model = ? AND status = 'approved')",
|
if err := database.DB.Model(&models.Price{}).Where("channel_type = ? AND model = ? AND status = 'approved'",
|
||||||
price.ChannelType, price.Model).Scan(&modelExists)
|
price.ChannelType, price.Model).Count(&count).Error; err != nil {
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if modelExists {
|
if count > 0 {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
// 设置状态和创建者
|
||||||
result, err := db.Exec(`
|
price.Status = "pending"
|
||||||
INSERT INTO price (model, model_type, billing_type, channel_type, currency, input_price, output_price,
|
|
||||||
price_source, status, created_by, created_at, updated_at)
|
// 创建记录
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`,
|
if err := database.DB.Create(&price).Error; err != nil {
|
||||||
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
|
|
||||||
price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy,
|
|
||||||
now, now)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create price"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create price"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id, _ := result.LastInsertId()
|
|
||||||
price.ID = uint(id)
|
|
||||||
price.Status = "pending"
|
|
||||||
price.CreatedAt = now
|
|
||||||
price.UpdatedAt = now
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, price)
|
c.JSON(http.StatusCreated, price)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,67 +107,100 @@ func UpdatePriceStatus(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db := c.MustGet("db").(*sql.DB)
|
// 查找价格记录
|
||||||
now := time.Now()
|
var price models.Price
|
||||||
|
if err := database.DB.Where("id = ?", id).First(&price).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
tx := database.DB.Begin()
|
||||||
|
|
||||||
if input.Status == "approved" {
|
if input.Status == "approved" {
|
||||||
// 如果是批准,将临时字段的值更新到正式字段
|
// 如果是批准,将临时字段的值更新到正式字段
|
||||||
_, err := db.Exec(`
|
updateMap := map[string]interface{}{
|
||||||
UPDATE price
|
"status": input.Status,
|
||||||
SET model = COALESCE(temp_model, model),
|
"updated_at": time.Now(),
|
||||||
model_type = COALESCE(temp_model_type, model_type),
|
}
|
||||||
billing_type = COALESCE(temp_billing_type, billing_type),
|
|
||||||
channel_type = COALESCE(temp_channel_type, channel_type),
|
// 如果临时字段有值,则更新主字段
|
||||||
currency = COALESCE(temp_currency, currency),
|
if price.TempModel != nil {
|
||||||
input_price = COALESCE(temp_input_price, input_price),
|
updateMap["model"] = *price.TempModel
|
||||||
output_price = COALESCE(temp_output_price, output_price),
|
}
|
||||||
price_source = COALESCE(temp_price_source, price_source),
|
if price.TempModelType != nil {
|
||||||
status = ?,
|
updateMap["model_type"] = *price.TempModelType
|
||||||
updated_at = ?,
|
}
|
||||||
temp_model = NULL,
|
if price.TempBillingType != nil {
|
||||||
temp_model_type = NULL,
|
updateMap["billing_type"] = *price.TempBillingType
|
||||||
temp_billing_type = NULL,
|
}
|
||||||
temp_channel_type = NULL,
|
if price.TempChannelType != nil {
|
||||||
temp_currency = NULL,
|
updateMap["channel_type"] = *price.TempChannelType
|
||||||
temp_input_price = NULL,
|
}
|
||||||
temp_output_price = NULL,
|
if price.TempCurrency != nil {
|
||||||
temp_price_source = NULL,
|
updateMap["currency"] = *price.TempCurrency
|
||||||
updated_by = NULL
|
}
|
||||||
WHERE id = ?`, input.Status, now, id)
|
if price.TempInputPrice != nil {
|
||||||
if err != nil {
|
updateMap["input_price"] = *price.TempInputPrice
|
||||||
|
}
|
||||||
|
if price.TempOutputPrice != nil {
|
||||||
|
updateMap["output_price"] = *price.TempOutputPrice
|
||||||
|
}
|
||||||
|
if price.TempPriceSource != nil {
|
||||||
|
updateMap["price_source"] = *price.TempPriceSource
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除所有临时字段
|
||||||
|
updateMap["temp_model"] = nil
|
||||||
|
updateMap["temp_model_type"] = nil
|
||||||
|
updateMap["temp_billing_type"] = nil
|
||||||
|
updateMap["temp_channel_type"] = nil
|
||||||
|
updateMap["temp_currency"] = nil
|
||||||
|
updateMap["temp_input_price"] = nil
|
||||||
|
updateMap["temp_output_price"] = nil
|
||||||
|
updateMap["temp_price_source"] = nil
|
||||||
|
updateMap["updated_by"] = nil
|
||||||
|
|
||||||
|
if err := tx.Model(&price).Updates(updateMap).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 如果是拒绝,清除临时字段
|
// 如果是拒绝,清除临时字段
|
||||||
_, err := db.Exec(`
|
if err := tx.Model(&price).Updates(map[string]interface{}{
|
||||||
UPDATE price
|
"status": input.Status,
|
||||||
SET status = ?,
|
"updated_at": time.Now(),
|
||||||
updated_at = ?,
|
"temp_model": nil,
|
||||||
temp_model = NULL,
|
"temp_model_type": nil,
|
||||||
temp_model_type = NULL,
|
"temp_billing_type": nil,
|
||||||
temp_billing_type = NULL,
|
"temp_channel_type": nil,
|
||||||
temp_channel_type = NULL,
|
"temp_currency": nil,
|
||||||
temp_currency = NULL,
|
"temp_input_price": nil,
|
||||||
temp_input_price = NULL,
|
"temp_output_price": nil,
|
||||||
temp_output_price = NULL,
|
"temp_price_source": nil,
|
||||||
temp_price_source = NULL,
|
"updated_by": nil,
|
||||||
updated_by = NULL
|
}).Error; err != nil {
|
||||||
WHERE id = ?`, input.Status, now, id)
|
tx.Rollback()
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "Status updated successfully",
|
"message": "Status updated successfully",
|
||||||
"status": input.Status,
|
"status": input.Status,
|
||||||
"updated_at": now,
|
"updated_at": time.Now(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePrice 更新价格
|
|
||||||
func UpdatePrice(c *gin.Context) {
|
func UpdatePrice(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
var price models.Price
|
var price models.Price
|
||||||
@ -234,23 +210,20 @@ func UpdatePrice(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证模型厂商ID是否存在
|
// 验证模型厂商ID是否存在
|
||||||
db := c.MustGet("db").(*sql.DB)
|
var provider models.Provider
|
||||||
var providerExists bool
|
if err := database.DB.Where("id = ?", price.ChannelType).First(&provider).Error; err != nil {
|
||||||
err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM provider WHERE id = ?)", price.ChannelType).Scan(&providerExists)
|
|
||||||
if err != nil || !providerExists {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查同一厂商下是否已存在相同名称的模型(排除当前正在编辑的记录)
|
// 检查同一厂商下是否已存在相同名称的模型(排除当前正在编辑的记录)
|
||||||
var modelExists bool
|
var count int64
|
||||||
err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM price WHERE channel_type = ? AND model = ? AND id != ? AND status = 'approved')",
|
if err := database.DB.Model(&models.Price{}).Where("channel_type = ? AND model = ? AND id != ? AND status = 'approved'",
|
||||||
price.ChannelType, price.Model, id).Scan(&modelExists)
|
price.ChannelType, price.Model, id).Count(&count).Error; err != nil {
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if modelExists {
|
if count > 0 {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -263,76 +236,73 @@ func UpdatePrice(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
currentUser := user.(*models.User)
|
currentUser := user.(*models.User)
|
||||||
|
|
||||||
now := time.Now()
|
// 查找现有记录
|
||||||
|
var existingPrice models.Price
|
||||||
var query string
|
if err := database.DB.Where("id = ?", id).First(&existingPrice).Error; err != nil {
|
||||||
var args []interface{}
|
c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 根据用户角色决定更新方式
|
// 根据用户角色决定更新方式
|
||||||
if currentUser.Role == "admin" {
|
if currentUser.Role == "admin" {
|
||||||
// 管理员直接更新主字段
|
// 管理员直接更新主字段
|
||||||
query = `
|
existingPrice.Model = price.Model
|
||||||
UPDATE price
|
existingPrice.ModelType = price.ModelType
|
||||||
SET model = ?, model_type = ?, billing_type = ?, channel_type = ?, currency = ?,
|
existingPrice.BillingType = price.BillingType
|
||||||
input_price = ?, output_price = ?, price_source = ?,
|
existingPrice.ChannelType = price.ChannelType
|
||||||
updated_by = ?, updated_at = ?, status = 'approved',
|
existingPrice.Currency = price.Currency
|
||||||
temp_model = NULL, temp_model_type = NULL, temp_billing_type = NULL,
|
existingPrice.InputPrice = price.InputPrice
|
||||||
temp_channel_type = NULL, temp_currency = NULL, temp_input_price = NULL,
|
existingPrice.OutputPrice = price.OutputPrice
|
||||||
temp_output_price = NULL, temp_price_source = NULL
|
existingPrice.PriceSource = price.PriceSource
|
||||||
WHERE id = ?`
|
existingPrice.Status = "approved"
|
||||||
args = []interface{}{
|
existingPrice.UpdatedBy = ¤tUser.Username
|
||||||
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
|
existingPrice.TempModel = nil
|
||||||
price.InputPrice, price.OutputPrice, price.PriceSource,
|
existingPrice.TempModelType = nil
|
||||||
currentUser.Username, now, id,
|
existingPrice.TempBillingType = nil
|
||||||
}
|
existingPrice.TempChannelType = nil
|
||||||
} else {
|
existingPrice.TempCurrency = nil
|
||||||
// 普通用户更新临时字段
|
existingPrice.TempInputPrice = nil
|
||||||
query = `
|
existingPrice.TempOutputPrice = nil
|
||||||
UPDATE price
|
existingPrice.TempPriceSource = nil
|
||||||
SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?,
|
|
||||||
temp_currency = ?, temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
|
|
||||||
updated_by = ?, updated_at = ?, status = 'pending'
|
|
||||||
WHERE id = ?`
|
|
||||||
args = []interface{}{
|
|
||||||
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
|
|
||||||
price.InputPrice, price.OutputPrice, price.PriceSource,
|
|
||||||
currentUser.Username, now, id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = db.Exec(query, args...)
|
if err := database.DB.Save(&existingPrice).Error; err != nil {
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// 普通用户更新临时字段
|
||||||
|
existingPrice.TempModel = &price.Model
|
||||||
|
existingPrice.TempModelType = &price.ModelType
|
||||||
|
existingPrice.TempBillingType = &price.BillingType
|
||||||
|
existingPrice.TempChannelType = &price.ChannelType
|
||||||
|
existingPrice.TempCurrency = &price.Currency
|
||||||
|
existingPrice.TempInputPrice = &price.InputPrice
|
||||||
|
existingPrice.TempOutputPrice = &price.OutputPrice
|
||||||
|
existingPrice.TempPriceSource = &price.PriceSource
|
||||||
|
existingPrice.Status = "pending"
|
||||||
|
existingPrice.UpdatedBy = ¤tUser.Username
|
||||||
|
|
||||||
// 获取更新后的价格信息
|
if err := database.DB.Save(&existingPrice).Error; err != nil {
|
||||||
err = db.QueryRow(`
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
|
||||||
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
|
return
|
||||||
price_source, status, created_at, updated_at, created_by,
|
}
|
||||||
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
|
}
|
||||||
temp_input_price, temp_output_price, temp_price_source, updated_by
|
|
||||||
FROM price WHERE id = ?`, id).Scan(
|
c.JSON(http.StatusOK, existingPrice)
|
||||||
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
|
}
|
||||||
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
|
|
||||||
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
|
func DeletePrice(c *gin.Context) {
|
||||||
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
|
id := c.Param("id")
|
||||||
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy)
|
|
||||||
if err != nil {
|
// 查找价格记录
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"})
|
var price models.Price
|
||||||
|
if err := database.DB.Where("id = ?", id).First(&price).Error; err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, price)
|
// 删除记录
|
||||||
}
|
if err := database.DB.Delete(&price).Error; err != nil {
|
||||||
|
|
||||||
// DeletePrice 删除价格
|
|
||||||
func DeletePrice(c *gin.Context) {
|
|
||||||
id := c.Param("id")
|
|
||||||
db := c.MustGet("db").(*sql.DB)
|
|
||||||
|
|
||||||
_, err := db.Exec("DELETE FROM price WHERE id = ?", id)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete price"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete price"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -352,123 +322,140 @@ type PriceRate struct {
|
|||||||
|
|
||||||
// GetPriceRates 获取价格倍率
|
// GetPriceRates 获取价格倍率
|
||||||
func GetPriceRates(c *gin.Context) {
|
func GetPriceRates(c *gin.Context) {
|
||||||
db := c.MustGet("db").(*sql.DB)
|
var prices []models.Price
|
||||||
rows, err := db.Query(`
|
if err := database.DB.Where("status = 'approved'").Find(&prices).Error; err != nil {
|
||||||
SELECT model, model_type, billing_type, channel_type,
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
|
||||||
CASE
|
|
||||||
WHEN currency = 'USD' THEN input_price / 2
|
|
||||||
ELSE input_price / 14
|
|
||||||
END as input_rate,
|
|
||||||
CASE
|
|
||||||
WHEN currency = 'USD' THEN output_price / 2
|
|
||||||
ELSE output_price / 14
|
|
||||||
END as output_rate
|
|
||||||
FROM price
|
|
||||||
WHERE status = 'approved'
|
|
||||||
ORDER BY model, channel_type`)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch price rates"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var rates []PriceRate
|
// 按模型分组
|
||||||
for rows.Next() {
|
modelMap := make(map[string]map[uint]models.Price)
|
||||||
var rate PriceRate
|
for _, price := range prices {
|
||||||
if err := rows.Scan(
|
if _, exists := modelMap[price.Model]; !exists {
|
||||||
&rate.Model,
|
modelMap[price.Model] = make(map[uint]models.Price)
|
||||||
&rate.ModelType,
|
}
|
||||||
&rate.Type,
|
modelMap[price.Model][price.ChannelType] = price
|
||||||
&rate.ChannelType,
|
}
|
||||||
&rate.Input,
|
|
||||||
&rate.Output); err != nil {
|
// 计算倍率
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price rate"})
|
var rates []PriceRate
|
||||||
return
|
for model, providers := range modelMap {
|
||||||
|
// 找出基准价格(通常是OpenAI的价格)
|
||||||
|
var basePrice models.Price
|
||||||
|
var found bool
|
||||||
|
for _, price := range providers {
|
||||||
|
if price.ChannelType == 1 { // 假设OpenAI的ID是1
|
||||||
|
basePrice = price
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算其他厂商相对于基准价格的倍率
|
||||||
|
for channelType, price := range providers {
|
||||||
|
if channelType == 1 {
|
||||||
|
continue // 跳过基准价格
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算输入和输出的倍率
|
||||||
|
inputRate := 0.0
|
||||||
|
if basePrice.InputPrice > 0 {
|
||||||
|
inputRate = price.InputPrice / basePrice.InputPrice
|
||||||
|
}
|
||||||
|
|
||||||
|
outputRate := 0.0
|
||||||
|
if basePrice.OutputPrice > 0 {
|
||||||
|
outputRate = price.OutputPrice / basePrice.OutputPrice
|
||||||
|
}
|
||||||
|
|
||||||
|
rates = append(rates, PriceRate{
|
||||||
|
Model: model,
|
||||||
|
ModelType: price.ModelType,
|
||||||
|
Type: price.BillingType,
|
||||||
|
ChannelType: channelType,
|
||||||
|
Input: inputRate,
|
||||||
|
Output: outputRate,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
rates = append(rates, rate)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, rates)
|
c.JSON(http.StatusOK, rates)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApproveAllPrices 批量通过所有待审核的价格
|
|
||||||
func ApproveAllPrices(c *gin.Context) {
|
func ApproveAllPrices(c *gin.Context) {
|
||||||
var input struct {
|
// 查找所有待审核的价格
|
||||||
Status string `json:"status" binding:"required,eq=approved"`
|
var pendingPrices []models.Price
|
||||||
}
|
if err := database.DB.Where("status = 'pending'").Find(&pendingPrices).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch pending prices"})
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db := c.MustGet("db").(*sql.DB)
|
// 开始事务
|
||||||
now := time.Now()
|
tx := database.DB.Begin()
|
||||||
|
|
||||||
// 获取当前用户
|
for _, price := range pendingPrices {
|
||||||
user, exists := c.Get("user")
|
updateMap := map[string]interface{}{
|
||||||
if !exists {
|
"status": "approved",
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
|
"updated_at": time.Now(),
|
||||||
return
|
|
||||||
}
|
|
||||||
currentUser := user.(*models.User)
|
|
||||||
|
|
||||||
// 只有管理员可以批量通过
|
|
||||||
if currentUser.Role != "admin" {
|
|
||||||
c.JSON(http.StatusForbidden, gin.H{"error": "Admin permission required"})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询待审核的价格数量
|
// 如果临时字段有值,则更新主字段
|
||||||
var pendingCount int
|
if price.TempModel != nil {
|
||||||
err := db.QueryRow("SELECT COUNT(*) FROM price WHERE status = 'pending'").Scan(&pendingCount)
|
updateMap["model"] = *price.TempModel
|
||||||
if err != nil {
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count pending prices"})
|
if price.TempModelType != nil {
|
||||||
return
|
updateMap["model_type"] = *price.TempModelType
|
||||||
|
}
|
||||||
|
if price.TempBillingType != nil {
|
||||||
|
updateMap["billing_type"] = *price.TempBillingType
|
||||||
|
}
|
||||||
|
if price.TempChannelType != nil {
|
||||||
|
updateMap["channel_type"] = *price.TempChannelType
|
||||||
|
}
|
||||||
|
if price.TempCurrency != nil {
|
||||||
|
updateMap["currency"] = *price.TempCurrency
|
||||||
|
}
|
||||||
|
if price.TempInputPrice != nil {
|
||||||
|
updateMap["input_price"] = *price.TempInputPrice
|
||||||
|
}
|
||||||
|
if price.TempOutputPrice != nil {
|
||||||
|
updateMap["output_price"] = *price.TempOutputPrice
|
||||||
|
}
|
||||||
|
if price.TempPriceSource != nil {
|
||||||
|
updateMap["price_source"] = *price.TempPriceSource
|
||||||
}
|
}
|
||||||
|
|
||||||
if pendingCount == 0 {
|
// 清除所有临时字段
|
||||||
c.JSON(http.StatusOK, gin.H{
|
updateMap["temp_model"] = nil
|
||||||
"message": "No pending prices to approve",
|
updateMap["temp_model_type"] = nil
|
||||||
"count": 0,
|
updateMap["temp_billing_type"] = nil
|
||||||
})
|
updateMap["temp_channel_type"] = nil
|
||||||
|
updateMap["temp_currency"] = nil
|
||||||
|
updateMap["temp_input_price"] = nil
|
||||||
|
updateMap["temp_output_price"] = nil
|
||||||
|
updateMap["temp_price_source"] = nil
|
||||||
|
updateMap["updated_by"] = nil
|
||||||
|
|
||||||
|
if err := tx.Model(&price).Updates(updateMap).Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to approve prices"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 批量更新所有待审核的价格
|
|
||||||
result, err := db.Exec(`
|
|
||||||
UPDATE price
|
|
||||||
SET model = COALESCE(temp_model, model),
|
|
||||||
model_type = COALESCE(temp_model_type, model_type),
|
|
||||||
billing_type = COALESCE(temp_billing_type, billing_type),
|
|
||||||
channel_type = COALESCE(temp_channel_type, channel_type),
|
|
||||||
currency = COALESCE(temp_currency, currency),
|
|
||||||
input_price = COALESCE(temp_input_price, input_price),
|
|
||||||
output_price = COALESCE(temp_output_price, output_price),
|
|
||||||
price_source = COALESCE(temp_price_source, price_source),
|
|
||||||
status = ?,
|
|
||||||
updated_at = ?,
|
|
||||||
temp_model = NULL,
|
|
||||||
temp_model_type = NULL,
|
|
||||||
temp_billing_type = NULL,
|
|
||||||
temp_channel_type = NULL,
|
|
||||||
temp_currency = NULL,
|
|
||||||
temp_input_price = NULL,
|
|
||||||
temp_output_price = NULL,
|
|
||||||
temp_price_source = NULL,
|
|
||||||
updated_by = NULL
|
|
||||||
WHERE status = 'pending'`, input.Status, now)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to approve all prices"})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedCount, _ := result.RowsAffected()
|
// 提交事务
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "All pending prices approved successfully",
|
"message": "All pending prices approved successfully",
|
||||||
"count": updatedCount,
|
"count": len(pendingPrices),
|
||||||
"updated_at": now,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user