重构价格管理模块,完全迁移到 GORM ORM

- 将价格相关处理函数从原生 SQL 完全迁移到 GORM
- 优化数据库查询逻辑,使用 GORM 的链式查询和更新方法
- 重构价格审核、更新和批量审核功能,使用事务处理
- 简化数据库操作,移除手动 SQL 查询和扫描逻辑
- 优化价格倍率计算方法,提高代码可读性和性能
This commit is contained in:
wood chen 2025-03-06 23:46:18 +08:00
parent 31f65a9301
commit 449f95d1b5
2 changed files with 276 additions and 288 deletions

View File

@ -81,10 +81,11 @@ func UpdateModelType(c *gin.Context) {
return return
} }
} else { } else {
// 直接更新 // 直接更新,只更新特定字段
existingType.TypeLabel = updateType.TypeLabel if err := database.DB.Model(&existingType).Updates(map[string]interface{}{
existingType.SortOrder = updateType.SortOrder "type_label": updateType.TypeLabel,
if err := database.DB.Save(&existingType).Error; err != nil { "sort_order": updateType.SortOrder,
}).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update model type"})
return return
} }

View File

@ -1,20 +1,17 @@
package handlers package handlers
import ( import (
"database/sql"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"aimodels-prices/database"
"aimodels-prices/models" "aimodels-prices/models"
) )
func GetPrices(c *gin.Context) { func GetPrices(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
// 获取分页和筛选参数 // 获取分页和筛选参数
page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
@ -30,75 +27,34 @@ func GetPrices(c *gin.Context) {
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
// 构建查询条件 // 构建查询
var conditions []string query := database.DB.Model(&models.Price{})
var args []interface{}
// 添加筛选条件
if channelType != "" { if channelType != "" {
conditions = append(conditions, "channel_type = ?") query = query.Where("channel_type = ?", channelType)
args = append(args, channelType)
} }
if modelType != "" { if modelType != "" {
conditions = append(conditions, "model_type = ?") query = query.Where("model_type = ?", modelType)
args = append(args, modelType)
}
// 组合WHERE子句
var whereClause string
if len(conditions) > 0 {
whereClause = "WHERE " + strings.Join(conditions, " AND ")
} }
// 获取总数 // 获取总数
var total int var total int64
countQuery := "SELECT COUNT(*) FROM price" if err := query.Count(&total).Error; err != nil {
if whereClause != "" {
countQuery += " " + whereClause
}
err := db.QueryRow(countQuery, args...).Scan(&total)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"})
return return
} }
// 使用分页查询 // 获取分页数据
query := ` var prices []models.Price
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price, if err := query.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&prices).Error; err != nil {
price_source, status, created_at, updated_at, created_by,
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
temp_input_price, temp_output_price, temp_price_source, updated_by
FROM price`
if whereClause != "" {
query += " " + whereClause
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, pageSize, offset)
rows, err := db.Query(query, args...)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
return return
} }
defer rows.Close()
var prices []models.Price
for rows.Next() {
var price models.Price
if err := rows.Scan(
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"})
return
}
prices = append(prices, price)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"total": total, "total": total,
"prices": prices, "data": prices,
}) })
} }
@ -110,46 +66,33 @@ func CreatePrice(c *gin.Context) {
} }
// 验证模型厂商ID是否存在 // 验证模型厂商ID是否存在
db := c.MustGet("db").(*sql.DB) var provider models.Provider
var providerExists bool if err := database.DB.Where("id = ?", price.ChannelType).First(&provider).Error; err != nil {
err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM provider WHERE id = ?)", price.ChannelType).Scan(&providerExists)
if err != nil || !providerExists {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return return
} }
// 检查同一厂商下是否已存在相同名称的模型 // 检查同一厂商下是否已存在相同名称的模型
var modelExists bool var count int64
err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM price WHERE channel_type = ? AND model = ? AND status = 'approved')", if err := database.DB.Model(&models.Price{}).Where("channel_type = ? AND model = ? AND status = 'approved'",
price.ChannelType, price.Model).Scan(&modelExists) price.ChannelType, price.Model).Count(&count).Error; err != nil {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"})
return return
} }
if modelExists { if count > 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"})
return return
} }
now := time.Now() // 设置状态和创建者
result, err := db.Exec(` price.Status = "pending"
INSERT INTO price (model, model_type, billing_type, channel_type, currency, input_price, output_price,
price_source, status, created_by, created_at, updated_at) // 创建记录
VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`, if err := database.DB.Create(&price).Error; err != nil {
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy,
now, now)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create price"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create price"})
return return
} }
id, _ := result.LastInsertId()
price.ID = uint(id)
price.Status = "pending"
price.CreatedAt = now
price.UpdatedAt = now
c.JSON(http.StatusCreated, price) c.JSON(http.StatusCreated, price)
} }
@ -164,67 +107,100 @@ func UpdatePriceStatus(c *gin.Context) {
return return
} }
db := c.MustGet("db").(*sql.DB) // 查找价格记录
now := time.Now() var price models.Price
if err := database.DB.Where("id = ?", id).First(&price).Error; err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
return
}
// 开始事务
tx := database.DB.Begin()
if input.Status == "approved" { if input.Status == "approved" {
// 如果是批准,将临时字段的值更新到正式字段 // 如果是批准,将临时字段的值更新到正式字段
_, err := db.Exec(` updateMap := map[string]interface{}{
UPDATE price "status": input.Status,
SET model = COALESCE(temp_model, model), "updated_at": time.Now(),
model_type = COALESCE(temp_model_type, model_type), }
billing_type = COALESCE(temp_billing_type, billing_type),
channel_type = COALESCE(temp_channel_type, channel_type), // 如果临时字段有值,则更新主字段
currency = COALESCE(temp_currency, currency), if price.TempModel != nil {
input_price = COALESCE(temp_input_price, input_price), updateMap["model"] = *price.TempModel
output_price = COALESCE(temp_output_price, output_price), }
price_source = COALESCE(temp_price_source, price_source), if price.TempModelType != nil {
status = ?, updateMap["model_type"] = *price.TempModelType
updated_at = ?, }
temp_model = NULL, if price.TempBillingType != nil {
temp_model_type = NULL, updateMap["billing_type"] = *price.TempBillingType
temp_billing_type = NULL, }
temp_channel_type = NULL, if price.TempChannelType != nil {
temp_currency = NULL, updateMap["channel_type"] = *price.TempChannelType
temp_input_price = NULL, }
temp_output_price = NULL, if price.TempCurrency != nil {
temp_price_source = NULL, updateMap["currency"] = *price.TempCurrency
updated_by = NULL }
WHERE id = ?`, input.Status, now, id) if price.TempInputPrice != nil {
if err != nil { updateMap["input_price"] = *price.TempInputPrice
}
if price.TempOutputPrice != nil {
updateMap["output_price"] = *price.TempOutputPrice
}
if price.TempPriceSource != nil {
updateMap["price_source"] = *price.TempPriceSource
}
// 清除所有临时字段
updateMap["temp_model"] = nil
updateMap["temp_model_type"] = nil
updateMap["temp_billing_type"] = nil
updateMap["temp_channel_type"] = nil
updateMap["temp_currency"] = nil
updateMap["temp_input_price"] = nil
updateMap["temp_output_price"] = nil
updateMap["temp_price_source"] = nil
updateMap["updated_by"] = nil
if err := tx.Model(&price).Updates(updateMap).Error; err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
return return
} }
} else { } else {
// 如果是拒绝,清除临时字段 // 如果是拒绝,清除临时字段
_, err := db.Exec(` if err := tx.Model(&price).Updates(map[string]interface{}{
UPDATE price "status": input.Status,
SET status = ?, "updated_at": time.Now(),
updated_at = ?, "temp_model": nil,
temp_model = NULL, "temp_model_type": nil,
temp_model_type = NULL, "temp_billing_type": nil,
temp_billing_type = NULL, "temp_channel_type": nil,
temp_channel_type = NULL, "temp_currency": nil,
temp_currency = NULL, "temp_input_price": nil,
temp_input_price = NULL, "temp_output_price": nil,
temp_output_price = NULL, "temp_price_source": nil,
temp_price_source = NULL, "updated_by": nil,
updated_by = NULL }).Error; err != nil {
WHERE id = ?`, input.Status, now, id) tx.Rollback()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
return return
} }
} }
// 提交事务
if err := tx.Commit().Error; err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
return
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "Status updated successfully", "message": "Status updated successfully",
"status": input.Status, "status": input.Status,
"updated_at": now, "updated_at": time.Now(),
}) })
} }
// UpdatePrice 更新价格
func UpdatePrice(c *gin.Context) { func UpdatePrice(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
var price models.Price var price models.Price
@ -234,23 +210,20 @@ func UpdatePrice(c *gin.Context) {
} }
// 验证模型厂商ID是否存在 // 验证模型厂商ID是否存在
db := c.MustGet("db").(*sql.DB) var provider models.Provider
var providerExists bool if err := database.DB.Where("id = ?", price.ChannelType).First(&provider).Error; err != nil {
err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM provider WHERE id = ?)", price.ChannelType).Scan(&providerExists)
if err != nil || !providerExists {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
return return
} }
// 检查同一厂商下是否已存在相同名称的模型(排除当前正在编辑的记录) // 检查同一厂商下是否已存在相同名称的模型(排除当前正在编辑的记录)
var modelExists bool var count int64
err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM price WHERE channel_type = ? AND model = ? AND id != ? AND status = 'approved')", if err := database.DB.Model(&models.Price{}).Where("channel_type = ? AND model = ? AND id != ? AND status = 'approved'",
price.ChannelType, price.Model, id).Scan(&modelExists) price.ChannelType, price.Model, id).Count(&count).Error; err != nil {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check model existence"})
return return
} }
if modelExists { if count > 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Model with the same name already exists for this provider"})
return return
} }
@ -263,76 +236,73 @@ func UpdatePrice(c *gin.Context) {
} }
currentUser := user.(*models.User) currentUser := user.(*models.User)
now := time.Now() // 查找现有记录
var existingPrice models.Price
var query string if err := database.DB.Where("id = ?", id).First(&existingPrice).Error; err != nil {
var args []interface{} c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
return
}
// 根据用户角色决定更新方式 // 根据用户角色决定更新方式
if currentUser.Role == "admin" { if currentUser.Role == "admin" {
// 管理员直接更新主字段 // 管理员直接更新主字段
query = ` existingPrice.Model = price.Model
UPDATE price existingPrice.ModelType = price.ModelType
SET model = ?, model_type = ?, billing_type = ?, channel_type = ?, currency = ?, existingPrice.BillingType = price.BillingType
input_price = ?, output_price = ?, price_source = ?, existingPrice.ChannelType = price.ChannelType
updated_by = ?, updated_at = ?, status = 'approved', existingPrice.Currency = price.Currency
temp_model = NULL, temp_model_type = NULL, temp_billing_type = NULL, existingPrice.InputPrice = price.InputPrice
temp_channel_type = NULL, temp_currency = NULL, temp_input_price = NULL, existingPrice.OutputPrice = price.OutputPrice
temp_output_price = NULL, temp_price_source = NULL existingPrice.PriceSource = price.PriceSource
WHERE id = ?` existingPrice.Status = "approved"
args = []interface{}{ existingPrice.UpdatedBy = &currentUser.Username
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency, existingPrice.TempModel = nil
price.InputPrice, price.OutputPrice, price.PriceSource, existingPrice.TempModelType = nil
currentUser.Username, now, id, existingPrice.TempBillingType = nil
} existingPrice.TempChannelType = nil
} else { existingPrice.TempCurrency = nil
// 普通用户更新临时字段 existingPrice.TempInputPrice = nil
query = ` existingPrice.TempOutputPrice = nil
UPDATE price existingPrice.TempPriceSource = nil
SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?,
temp_currency = ?, temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
updated_by = ?, updated_at = ?, status = 'pending'
WHERE id = ?`
args = []interface{}{
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource,
currentUser.Username, now, id,
}
}
_, err = db.Exec(query, args...) if err := database.DB.Save(&existingPrice).Error; err != nil {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
return return
} }
} else {
// 普通用户更新临时字段
existingPrice.TempModel = &price.Model
existingPrice.TempModelType = &price.ModelType
existingPrice.TempBillingType = &price.BillingType
existingPrice.TempChannelType = &price.ChannelType
existingPrice.TempCurrency = &price.Currency
existingPrice.TempInputPrice = &price.InputPrice
existingPrice.TempOutputPrice = &price.OutputPrice
existingPrice.TempPriceSource = &price.PriceSource
existingPrice.Status = "pending"
existingPrice.UpdatedBy = &currentUser.Username
// 获取更新后的价格信息 if err := database.DB.Save(&existingPrice).Error; err != nil {
err = db.QueryRow(` c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price, return
price_source, status, created_at, updated_at, created_by, }
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency, }
temp_input_price, temp_output_price, temp_price_source, updated_by
FROM price WHERE id = ?`, id).Scan( c.JSON(http.StatusOK, existingPrice)
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency, }
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy, func DeletePrice(c *gin.Context) {
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency, id := c.Param("id")
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy)
if err != nil { // 查找价格记录
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"}) var price models.Price
if err := database.DB.Where("id = ?", id).First(&price).Error; err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
return return
} }
c.JSON(http.StatusOK, price) // 删除记录
} if err := database.DB.Delete(&price).Error; err != nil {
// DeletePrice 删除价格
func DeletePrice(c *gin.Context) {
id := c.Param("id")
db := c.MustGet("db").(*sql.DB)
_, err := db.Exec("DELETE FROM price WHERE id = ?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete price"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete price"})
return return
} }
@ -352,123 +322,140 @@ type PriceRate struct {
// GetPriceRates 获取价格倍率 // GetPriceRates 获取价格倍率
func GetPriceRates(c *gin.Context) { func GetPriceRates(c *gin.Context) {
db := c.MustGet("db").(*sql.DB) var prices []models.Price
rows, err := db.Query(` if err := database.DB.Where("status = 'approved'").Find(&prices).Error; err != nil {
SELECT model, model_type, billing_type, channel_type, c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
CASE
WHEN currency = 'USD' THEN input_price / 2
ELSE input_price / 14
END as input_rate,
CASE
WHEN currency = 'USD' THEN output_price / 2
ELSE output_price / 14
END as output_rate
FROM price
WHERE status = 'approved'
ORDER BY model, channel_type`)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch price rates"})
return return
} }
defer rows.Close()
var rates []PriceRate // 按模型分组
for rows.Next() { modelMap := make(map[string]map[uint]models.Price)
var rate PriceRate for _, price := range prices {
if err := rows.Scan( if _, exists := modelMap[price.Model]; !exists {
&rate.Model, modelMap[price.Model] = make(map[uint]models.Price)
&rate.ModelType, }
&rate.Type, modelMap[price.Model][price.ChannelType] = price
&rate.ChannelType, }
&rate.Input,
&rate.Output); err != nil { // 计算倍率
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price rate"}) var rates []PriceRate
return for model, providers := range modelMap {
// 找出基准价格通常是OpenAI的价格
var basePrice models.Price
var found bool
for _, price := range providers {
if price.ChannelType == 1 { // 假设OpenAI的ID是1
basePrice = price
found = true
break
}
}
if !found {
continue
}
// 计算其他厂商相对于基准价格的倍率
for channelType, price := range providers {
if channelType == 1 {
continue // 跳过基准价格
}
// 计算输入和输出的倍率
inputRate := 0.0
if basePrice.InputPrice > 0 {
inputRate = price.InputPrice / basePrice.InputPrice
}
outputRate := 0.0
if basePrice.OutputPrice > 0 {
outputRate = price.OutputPrice / basePrice.OutputPrice
}
rates = append(rates, PriceRate{
Model: model,
ModelType: price.ModelType,
Type: price.BillingType,
ChannelType: channelType,
Input: inputRate,
Output: outputRate,
})
} }
rates = append(rates, rate)
} }
c.JSON(http.StatusOK, rates) c.JSON(http.StatusOK, rates)
} }
// ApproveAllPrices 批量通过所有待审核的价格
func ApproveAllPrices(c *gin.Context) { func ApproveAllPrices(c *gin.Context) {
var input struct { // 查找所有待审核的价格
Status string `json:"status" binding:"required,eq=approved"` var pendingPrices []models.Price
} if err := database.DB.Where("status = 'pending'").Find(&pendingPrices).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch pending prices"})
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
db := c.MustGet("db").(*sql.DB) // 开始事务
now := time.Now() tx := database.DB.Begin()
// 获取当前用户 for _, price := range pendingPrices {
user, exists := c.Get("user") updateMap := map[string]interface{}{
if !exists { "status": "approved",
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"}) "updated_at": time.Now(),
return
}
currentUser := user.(*models.User)
// 只有管理员可以批量通过
if currentUser.Role != "admin" {
c.JSON(http.StatusForbidden, gin.H{"error": "Admin permission required"})
return
} }
// 查询待审核的价格数量 // 如果临时字段有值,则更新主字段
var pendingCount int if price.TempModel != nil {
err := db.QueryRow("SELECT COUNT(*) FROM price WHERE status = 'pending'").Scan(&pendingCount) updateMap["model"] = *price.TempModel
if err != nil { }
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count pending prices"}) if price.TempModelType != nil {
return updateMap["model_type"] = *price.TempModelType
}
if price.TempBillingType != nil {
updateMap["billing_type"] = *price.TempBillingType
}
if price.TempChannelType != nil {
updateMap["channel_type"] = *price.TempChannelType
}
if price.TempCurrency != nil {
updateMap["currency"] = *price.TempCurrency
}
if price.TempInputPrice != nil {
updateMap["input_price"] = *price.TempInputPrice
}
if price.TempOutputPrice != nil {
updateMap["output_price"] = *price.TempOutputPrice
}
if price.TempPriceSource != nil {
updateMap["price_source"] = *price.TempPriceSource
} }
if pendingCount == 0 { // 清除所有临时字段
c.JSON(http.StatusOK, gin.H{ updateMap["temp_model"] = nil
"message": "No pending prices to approve", updateMap["temp_model_type"] = nil
"count": 0, updateMap["temp_billing_type"] = nil
}) updateMap["temp_channel_type"] = nil
updateMap["temp_currency"] = nil
updateMap["temp_input_price"] = nil
updateMap["temp_output_price"] = nil
updateMap["temp_price_source"] = nil
updateMap["updated_by"] = nil
if err := tx.Model(&price).Updates(updateMap).Error; err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to approve prices"})
return return
} }
// 批量更新所有待审核的价格
result, err := db.Exec(`
UPDATE price
SET model = COALESCE(temp_model, model),
model_type = COALESCE(temp_model_type, model_type),
billing_type = COALESCE(temp_billing_type, billing_type),
channel_type = COALESCE(temp_channel_type, channel_type),
currency = COALESCE(temp_currency, currency),
input_price = COALESCE(temp_input_price, input_price),
output_price = COALESCE(temp_output_price, output_price),
price_source = COALESCE(temp_price_source, price_source),
status = ?,
updated_at = ?,
temp_model = NULL,
temp_model_type = NULL,
temp_billing_type = NULL,
temp_channel_type = NULL,
temp_currency = NULL,
temp_input_price = NULL,
temp_output_price = NULL,
temp_price_source = NULL,
updated_by = NULL
WHERE status = 'pending'`, input.Status, now)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to approve all prices"})
return
} }
updatedCount, _ := result.RowsAffected() // 提交事务
if err := tx.Commit().Error; err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
return
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "All pending prices approved successfully", "message": "All pending prices approved successfully",
"count": updatedCount, "count": len(pendingPrices),
"updated_at": now,
}) })
} }