From 2efb33fc2f4568c6100f0eef46d5c6c5ffe9b18b Mon Sep 17 00:00:00 2001 From: wood chen Date: Fri, 21 Feb 2025 12:07:56 +0800 Subject: [PATCH] Enhance provider management and price filtering with model type and ID handling --- backend/handlers/prices.go | 66 ++++++++++++++++++++------ backend/handlers/providers.go | 80 +++++++++++++++++++++++--------- frontend/src/views/Prices.vue | 39 +++++++++++++++- frontend/src/views/Providers.vue | 38 +++++++-------- 4 files changed, 166 insertions(+), 57 deletions(-) diff --git a/backend/handlers/prices.go b/backend/handlers/prices.go index 57dc73a..488ae90 100644 --- a/backend/handlers/prices.go +++ b/backend/handlers/prices.go @@ -4,6 +4,7 @@ import ( "database/sql" "net/http" "strconv" + "strings" "time" "github.com/gin-gonic/gin" @@ -17,7 +18,8 @@ func GetPrices(c *gin.Context) { // 获取分页和筛选参数 page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) - channelType := c.Query("channel_type") // 新增: 获取厂商筛选参数 + channelType := c.Query("channel_type") // 厂商筛选参数 + modelType := c.Query("model_type") // 模型类型筛选参数 if page < 1 { page = 1 @@ -29,12 +31,23 @@ func GetPrices(c *gin.Context) { offset := (page - 1) * pageSize // 构建查询条件 - var whereClause string + var conditions []string var args []interface{} + if channelType != "" { - whereClause = "WHERE channel_type = ?" + conditions = append(conditions, "channel_type = ?") args = append(args, channelType) } + if modelType != "" { + conditions = append(conditions, "model_type = ?") + args = append(args, modelType) + } + + // 组合WHERE子句 + var whereClause string + if len(conditions) > 0 { + whereClause = "WHERE " + strings.Join(conditions, " AND ") + } // 获取总数 var total int @@ -225,16 +238,43 @@ func UpdatePrice(c *gin.Context) { currentUser := user.(*models.User) now := time.Now() - // 将新的价格信息存储到临时字段 - _, err = db.Exec(` - UPDATE price - SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?, - temp_input_price = ?, temp_output_price = ?, temp_price_source = ?, - updated_by = ?, updated_at = ?, status = 'pending' - WHERE id = ?`, - price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency, - price.InputPrice, price.OutputPrice, price.PriceSource, - currentUser.Username, now, id) + + var query string + var args []interface{} + + // 根据用户角色决定更新方式 + if currentUser.Role == "admin" { + // 管理员直接更新主字段 + query = ` + UPDATE price + SET model = ?, model_type = ?, billing_type = ?, channel_type = ?, currency = ?, + input_price = ?, output_price = ?, price_source = ?, + updated_by = ?, updated_at = ?, status = 'approved', + temp_model = NULL, temp_model_type = NULL, temp_billing_type = NULL, + temp_channel_type = NULL, temp_currency = NULL, temp_input_price = NULL, + temp_output_price = NULL, temp_price_source = NULL + WHERE id = ?` + args = []interface{}{ + price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency, + price.InputPrice, price.OutputPrice, price.PriceSource, + currentUser.Username, now, id, + } + } else { + // 普通用户更新临时字段 + query = ` + UPDATE price + SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?, + temp_currency = ?, temp_input_price = ?, temp_output_price = ?, temp_price_source = ?, + updated_by = ?, updated_at = ?, status = 'pending' + WHERE id = ?` + args = []interface{}{ + price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency, + price.InputPrice, price.OutputPrice, price.PriceSource, + currentUser.Username, now, id, + } + } + + _, err = db.Exec(query, args...) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"}) return diff --git a/backend/handlers/providers.go b/backend/handlers/providers.go index baa1b9f..107e40d 100644 --- a/backend/handlers/providers.go +++ b/backend/handlers/providers.go @@ -82,12 +82,7 @@ func CreateProvider(c *gin.Context) { // UpdateProvider 更新模型厂商 func UpdateProvider(c *gin.Context) { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid ID"}) - return - } - + oldID := c.Param("id") var provider models.Provider if err := c.ShouldBindJSON(&provider); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -95,25 +90,68 @@ func UpdateProvider(c *gin.Context) { } db := c.MustGet("db").(*sql.DB) - now := time.Now() - _, err = db.Exec(` - UPDATE provider - SET name = ?, icon = ?, updated_at = ? - WHERE id = ?`, - provider.Name, provider.Icon, now, id) + + // 开始事务 + tx, err := db.Begin() if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to begin transaction"}) return } - // 获取更新后的模型厂商信息 - err = db.QueryRow(` - SELECT id, name, icon, created_at, updated_at, created_by - FROM provider WHERE id = ?`, id).Scan( - &provider.ID, &provider.Name, &provider.Icon, - &provider.CreatedAt, &provider.UpdatedAt, &provider.CreatedBy) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated provider"}) + // 如果ID发生变化,需要同时更新price表中的引用 + if oldID != strconv.FormatUint(uint64(provider.ID), 10) { + // 更新price表中的channel_type + _, err = tx.Exec("UPDATE price SET channel_type = ? WHERE channel_type = ?", provider.ID, oldID) + if err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price references"}) + return + } + + // 更新price表中的temp_channel_type + _, err = tx.Exec("UPDATE price SET temp_channel_type = ? WHERE temp_channel_type = ?", provider.ID, oldID) + if err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price temp references"}) + return + } + + // 删除旧记录 + _, err = tx.Exec("DELETE FROM provider WHERE id = ?", oldID) + if err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old provider"}) + return + } + + // 插入新记录 + _, err = tx.Exec(` + INSERT INTO provider (id, name, icon, created_at, updated_at) + VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + `, provider.ID, provider.Name, provider.Icon) + if err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new provider"}) + return + } + } else { + // 如果ID没有变化,直接更新 + _, err = tx.Exec(` + UPDATE provider + SET name = ?, icon = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, provider.Name, provider.Icon, oldID) + if err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"}) + return + } + } + + // 提交事务 + if err := tx.Commit(); err != nil { + tx.Rollback() + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"}) return } diff --git a/frontend/src/views/Prices.vue b/frontend/src/views/Prices.vue index 5097cd6..40f2e7c 100644 --- a/frontend/src/views/Prices.vue +++ b/frontend/src/views/Prices.vue @@ -43,6 +43,24 @@ +
+
模型类别:
+
+ 全部 + + {{ label }} + +
+
+