wood chen da79bf3d6d 新增批量审核价格功能并优化前端交互
- 后端新增 `/prices/approve-all` 接口,支持管理员一键通过所有待审核价格
- 前端价格管理页面添加"全部通过"按钮,仅对管理员可见
- 优化批量添加价格页面交互,增加行复制和删除按钮
- 调整价格输入为可为空,移除默认值为0的限制
2025-03-06 22:43:30 +08:00

449 lines
13 KiB
Go

package handlers
import (
"database/sql"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"aimodels-prices/models"
)
// GetPrices returns one page of price records, newest first, optionally
// filtered by provider and model type.
//
// Query parameters:
//   - page (default 1) and pageSize (default 20); values that are missing,
//     unparseable or below 1 fall back to the defaults.
//   - channel_type: only return prices for this provider.
//   - model_type: only return prices of this model type.
//
// The response is {"total": <number of matching rows>, "prices": [...]}.
func GetPrices(c *gin.Context) {
	db := c.MustGet("db").(*sql.DB)

	// Paging and filter parameters. Atoi yields 0 on bad input, which the
	// range checks below turn back into the defaults.
	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
	channelType := c.Query("channel_type") // provider filter
	modelType := c.Query("model_type")     // model type filter

	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 20
	}
	offset := (page - 1) * pageSize

	// Build the WHERE clause from whichever filters were supplied.
	var conditions []string
	var args []interface{}
	if channelType != "" {
		conditions = append(conditions, "channel_type = ?")
		args = append(args, channelType)
	}
	if modelType != "" {
		conditions = append(conditions, "model_type = ?")
		args = append(args, modelType)
	}
	var whereClause string
	if len(conditions) > 0 {
		whereClause = " WHERE " + strings.Join(conditions, " AND ")
	}

	// Total number of matching rows, used by the client for pagination.
	var total int
	if err := db.QueryRow("SELECT COUNT(*) FROM price"+whereClause, args...).Scan(&total); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to count prices"})
		return
	}

	// id breaks ties between rows with the same created_at (e.g. rows added
	// by a batch insert); without it their order is unspecified and
	// LIMIT/OFFSET paging can repeat or skip them.
	query := `
		SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
			price_source, status, created_at, updated_at, created_by,
			temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
			temp_input_price, temp_output_price, temp_price_source, updated_by
		FROM price` + whereClause + `
		ORDER BY created_at DESC, id DESC
		LIMIT ? OFFSET ?`
	args = append(args, pageSize, offset)

	rows, err := db.Query(query, args...)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
		return
	}
	defer rows.Close()

	var prices []models.Price
	for rows.Next() {
		var price models.Price
		if err := rows.Scan(
			&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
			&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
			&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
			&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
			&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"})
			return
		}
		prices = append(prices, price)
	}
	// rows.Next returns false both at the end of the result set and on an
	// error; only rows.Err tells them apart, so check it to avoid returning
	// a silently truncated page.
	if err := rows.Err(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch prices"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"total":  total,
		"prices": prices,
	})
}
// CreatePrice inserts a new price record built from the JSON request body.
//
// The provider referenced by channel_type must already exist. New records
// always start with status "pending". On success it responds 201 with the
// stored record, including its new ID and timestamps.
func CreatePrice(c *gin.Context) {
	var price models.Price
	if err := c.ShouldBindJSON(&price); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	db := c.MustGet("db").(*sql.DB)

	// Verify that the provider ID refers to an existing provider. A failed
	// lookup is a server-side problem, not a bad request, so it is reported
	// separately from a provider that genuinely does not exist.
	var providerExists bool
	if err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM provider WHERE id = ?)", price.ChannelType).Scan(&providerExists); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate provider"})
		return
	}
	if !providerExists {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}

	now := time.Now()
	// NOTE(review): created_by is taken from the request body unchanged, so a
	// client can attribute the record to any user — confirm whether it should
	// come from the authenticated user instead.
	result, err := db.Exec(`
		INSERT INTO price (model, model_type, billing_type, channel_type, currency, input_price, output_price,
			price_source, status, created_by, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`,
		price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
		price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy,
		now, now)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create price"})
		return
	}

	// The row is already stored at this point. Turning a LastInsertId failure
	// into an error response would invite a retry that inserts a duplicate,
	// so the error is deliberately ignored and ID is left as 0 in that case.
	id, _ := result.LastInsertId()
	price.ID = uint(id)
	price.Status = "pending"
	price.CreatedAt = now
	price.UpdatedAt = now
	c.JSON(http.StatusCreated, price)
}
// UpdatePriceStatus approves or rejects the price identified by the :id path
// parameter. The JSON body must be {"status": "approved"} or
// {"status": "rejected"}.
//
// Approving copies every non-NULL temp_* column (a pending edit) into the
// corresponding live column. Rejecting leaves the live columns untouched.
// In both cases the new status is stored, updated_at is set to the current
// time, and all temp_* columns plus updated_by are cleared.
func UpdatePriceStatus(c *gin.Context) {
	priceID := c.Param("id")

	var body struct {
		Status string `json:"status" binding:"required,oneof=approved rejected"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	conn := c.MustGet("db").(*sql.DB)
	stamp := time.Now()

	// Choose the statement; both take (status, updated_at, id) in that order.
	var stmt string
	switch body.Status {
	case "approved":
		// COALESCE keeps the live value for any column without a pending edit.
		stmt = `
			UPDATE price
			SET model = COALESCE(temp_model, model),
				model_type = COALESCE(temp_model_type, model_type),
				billing_type = COALESCE(temp_billing_type, billing_type),
				channel_type = COALESCE(temp_channel_type, channel_type),
				currency = COALESCE(temp_currency, currency),
				input_price = COALESCE(temp_input_price, input_price),
				output_price = COALESCE(temp_output_price, output_price),
				price_source = COALESCE(temp_price_source, price_source),
				status = ?,
				updated_at = ?,
				temp_model = NULL,
				temp_model_type = NULL,
				temp_billing_type = NULL,
				temp_channel_type = NULL,
				temp_currency = NULL,
				temp_input_price = NULL,
				temp_output_price = NULL,
				temp_price_source = NULL,
				updated_by = NULL
			WHERE id = ?`
	default:
		// Rejected: discard the pending edit and keep the live values.
		stmt = `
			UPDATE price
			SET status = ?,
				updated_at = ?,
				temp_model = NULL,
				temp_model_type = NULL,
				temp_billing_type = NULL,
				temp_channel_type = NULL,
				temp_currency = NULL,
				temp_input_price = NULL,
				temp_output_price = NULL,
				temp_price_source = NULL,
				updated_by = NULL
			WHERE id = ?`
	}

	if _, execErr := conn.Exec(stmt, body.Status, stamp, priceID); execErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price status"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message":    "Status updated successfully",
		"status":     body.Status,
		"updated_at": stamp,
	})
}
// UpdatePrice edits the price identified by the :id path parameter using the
// JSON request body.
//
// Admins write straight to the live columns. That also sets the status to
// "approved" and clears any pending temp_* values. Other users write to the
// temp_* columns instead, and the status is set to "pending". In both cases
// updated_by and updated_at are recorded. The response is the record as
// stored after the update. A 404 is returned when no price has the given id.
func UpdatePrice(c *gin.Context) {
	id := c.Param("id")

	var price models.Price
	if err := c.ShouldBindJSON(&price); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// The authenticated user decides how the update is applied. A missing,
	// mistyped or nil value is treated as "not authenticated" instead of
	// panicking on the type assertion or the nil dereference below.
	user, exists := c.Get("user")
	if !exists {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
		return
	}
	currentUser, ok := user.(*models.User)
	if !ok || currentUser == nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
		return
	}

	db := c.MustGet("db").(*sql.DB)

	// Verify that the provider ID refers to an existing provider. A failed
	// lookup is a server-side problem, not a bad request.
	var providerExists bool
	if err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM provider WHERE id = ?)", price.ChannelType).Scan(&providerExists); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate provider"})
		return
	}
	if !providerExists {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider ID"})
		return
	}

	now := time.Now()
	// Both statements below take exactly these parameters, in this order.
	args := []interface{}{
		price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
		price.InputPrice, price.OutputPrice, price.PriceSource,
		currentUser.Username, now, id,
	}

	var query string
	if currentUser.Role == "admin" {
		// Admins update the live columns directly.
		query = `
			UPDATE price
			SET model = ?, model_type = ?, billing_type = ?, channel_type = ?, currency = ?,
				input_price = ?, output_price = ?, price_source = ?,
				updated_by = ?, updated_at = ?, status = 'approved',
				temp_model = NULL, temp_model_type = NULL, temp_billing_type = NULL,
				temp_channel_type = NULL, temp_currency = NULL, temp_input_price = NULL,
				temp_output_price = NULL, temp_price_source = NULL
			WHERE id = ?`
	} else {
		// Other users stage their edit in the temp_* columns for review.
		query = `
			UPDATE price
			SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?,
				temp_currency = ?, temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
				updated_by = ?, updated_at = ?, status = 'pending'
			WHERE id = ?`
	}

	if _, err := db.Exec(query, args...); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
		return
	}

	// Re-read the row so the response reflects exactly what is stored.
	err := db.QueryRow(`
		SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
			price_source, status, created_at, updated_at, created_by,
			temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
			temp_input_price, temp_output_price, temp_price_source, updated_by
		FROM price WHERE id = ?`, id).Scan(
		&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
		&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
		&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
		&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
		&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy)
	// Row.Scan returns sql.ErrNoRows itself (unwrapped) when nothing matched,
	// i.e. the id does not exist and the UPDATE above was a no-op.
	if err == sql.ErrNoRows {
		c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
		return
	}
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"})
		return
	}

	c.JSON(http.StatusOK, price)
}
// DeletePrice removes the price identified by the :id path parameter.
// It responds 404 when no price with that id exists.
func DeletePrice(c *gin.Context) {
	id := c.Param("id")
	db := c.MustGet("db").(*sql.DB)

	result, err := db.Exec("DELETE FROM price WHERE id = ?", id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete price"})
		return
	}

	// Report a missing price instead of a misleading success. If the driver
	// cannot report the affected row count, assume the delete succeeded
	// rather than failing a request whose DELETE already ran.
	if n, err := result.RowsAffected(); err == nil && n == 0 {
		c.JSON(http.StatusNotFound, gin.H{"error": "Price not found"})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Price deleted successfully"})
}
// PriceRate is one approved price converted to a rate, as returned by
// GetPriceRates. Field order determines the key order of the JSON output.
type PriceRate struct {
	Model       string  `json:"model"`        // model name
	ModelType   string  `json:"model_type"`   // model type
	Type        string  `json:"type"`         // billing type (price.billing_type)
	ChannelType uint    `json:"channel_type"` // provider ID (provider.id)
	Input       float64 `json:"input"`        // input price converted to a rate
	Output      float64 `json:"output"`       // output price converted to a rate
}
// GetPriceRates lists every approved price converted to a rate, ordered by
// model and then by provider.
//
// USD prices are divided by 2; prices in any other currency are divided by 14.
// NOTE(review): these divisors look like a rate base of 2 per price unit,
// with non-USD prices additionally converted at a fixed exchange rate of 7
// (2 * 7 = 14) — confirm before changing either constant.
func GetPriceRates(c *gin.Context) {
	db := c.MustGet("db").(*sql.DB)
	// NOTE(review): a NULL input_price/output_price would make the Scan into
	// float64 below fail and abort the whole response — confirm these columns
	// cannot be NULL for approved rows.
	rows, err := db.Query(`
		SELECT model, model_type, billing_type, channel_type,
			CASE
				WHEN currency = 'USD' THEN input_price / 2
				ELSE input_price / 14
			END as input_rate,
			CASE
				WHEN currency = 'USD' THEN output_price / 2
				ELSE output_price / 14
			END as output_rate
		FROM price
		WHERE status = 'approved'
		ORDER BY model, channel_type`)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch price rates"})
		return
	}
	defer rows.Close()

	var rates []PriceRate
	for rows.Next() {
		var rate PriceRate
		if err := rows.Scan(
			&rate.Model,
			&rate.ModelType,
			&rate.Type,
			&rate.ChannelType,
			&rate.Input,
			&rate.Output); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price rate"})
			return
		}
		rates = append(rates, rate)
	}
	// rows.Next also stops on errors; without this check a partial list would
	// be returned as if it were complete.
	if err := rows.Err(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch price rates"})
		return
	}

	c.JSON(http.StatusOK, rates)
}
// ApproveAllPrices approves every pending price in a single UPDATE statement.
// Only admins may call it, and the JSON body must be {"status": "approved"}.
//
// For each pending row, non-NULL temp_* columns (pending edits) are copied
// into the live columns. The status becomes "approved", updated_at is set to
// the current time, and the temp_* columns plus updated_by are cleared. The
// response reports how many rows were approved.
func ApproveAllPrices(c *gin.Context) {
	// Authorize first, so callers who may not use this endpoint learn nothing
	// about its expected input. A missing, mistyped or nil user value is
	// treated as "not authenticated" instead of panicking.
	user, exists := c.Get("user")
	if !exists {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
		return
	}
	currentUser, ok := user.(*models.User)
	if !ok || currentUser == nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
		return
	}
	if currentUser.Role != "admin" {
		c.JSON(http.StatusForbidden, gin.H{"error": "Admin permission required"})
		return
	}

	var input struct {
		Status string `json:"status" binding:"required,eq=approved"`
	}
	if err := c.ShouldBindJSON(&input); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	db := c.MustGet("db").(*sql.DB)
	now := time.Now()

	// The number of approved rows comes from the UPDATE itself. A separate
	// COUNT(*) beforehand costs an extra round trip and is not atomic with
	// the UPDATE, so the two could disagree about whether anything was pending.
	result, err := db.Exec(`
		UPDATE price
		SET model = COALESCE(temp_model, model),
			model_type = COALESCE(temp_model_type, model_type),
			billing_type = COALESCE(temp_billing_type, billing_type),
			channel_type = COALESCE(temp_channel_type, channel_type),
			currency = COALESCE(temp_currency, currency),
			input_price = COALESCE(temp_input_price, input_price),
			output_price = COALESCE(temp_output_price, output_price),
			price_source = COALESCE(temp_price_source, price_source),
			status = ?,
			updated_at = ?,
			temp_model = NULL,
			temp_model_type = NULL,
			temp_billing_type = NULL,
			temp_channel_type = NULL,
			temp_currency = NULL,
			temp_input_price = NULL,
			temp_output_price = NULL,
			temp_price_source = NULL,
			updated_by = NULL
		WHERE status = 'pending'`, input.Status, now)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to approve all prices"})
		return
	}

	updatedCount, err := result.RowsAffected()
	if err != nil {
		// The UPDATE already ran. Re-calling this endpoint is harmless because
		// it only touches rows that are still pending.
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to determine number of approved prices"})
		return
	}

	if updatedCount == 0 {
		c.JSON(http.StatusOK, gin.H{
			"message": "No pending prices to approve",
			"count":   0,
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message":    "All pending prices approved successfully",
		"count":      updatedCount,
		"updated_at": now,
	})
}