Enhance provider management and price filtering with model type and ID handling

This commit is contained in:
wood chen 2025-02-21 12:07:56 +08:00
parent 15a7c75145
commit 2efb33fc2f
4 changed files with 166 additions and 57 deletions

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
@ -17,7 +18,8 @@ func GetPrices(c *gin.Context) {
// 获取分页和筛选参数
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
channelType := c.Query("channel_type") // 新增: 获取厂商筛选参数
channelType := c.Query("channel_type") // 厂商筛选参数
modelType := c.Query("model_type") // 模型类型筛选参数
if page < 1 {
page = 1
@ -29,12 +31,23 @@ func GetPrices(c *gin.Context) {
offset := (page - 1) * pageSize
// 构建查询条件
var whereClause string
var conditions []string
var args []interface{}
if channelType != "" {
whereClause = "WHERE channel_type = ?"
conditions = append(conditions, "channel_type = ?")
args = append(args, channelType)
}
if modelType != "" {
conditions = append(conditions, "model_type = ?")
args = append(args, modelType)
}
// 组合WHERE子句
var whereClause string
if len(conditions) > 0 {
whereClause = "WHERE " + strings.Join(conditions, " AND ")
}
// 获取总数
var total int
@ -225,16 +238,43 @@ func UpdatePrice(c *gin.Context) {
currentUser := user.(*models.User)
now := time.Now()
// 将新的价格信息存储到临时字段
_, err = db.Exec(`
UPDATE price
SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?,
temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
updated_by = ?, updated_at = ?, status = 'pending'
WHERE id = ?`,
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource,
currentUser.Username, now, id)
var query string
var args []interface{}
// 根据用户角色决定更新方式
if currentUser.Role == "admin" {
// 管理员直接更新主字段
query = `
UPDATE price
SET model = ?, model_type = ?, billing_type = ?, channel_type = ?, currency = ?,
input_price = ?, output_price = ?, price_source = ?,
updated_by = ?, updated_at = ?, status = 'approved',
temp_model = NULL, temp_model_type = NULL, temp_billing_type = NULL,
temp_channel_type = NULL, temp_currency = NULL, temp_input_price = NULL,
temp_output_price = NULL, temp_price_source = NULL
WHERE id = ?`
args = []interface{}{
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource,
currentUser.Username, now, id,
}
} else {
// 普通用户更新临时字段
query = `
UPDATE price
SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?,
temp_currency = ?, temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
updated_by = ?, updated_at = ?, status = 'pending'
WHERE id = ?`
args = []interface{}{
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource,
currentUser.Username, now, id,
}
}
_, err = db.Exec(query, args...)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price"})
return

View File

@ -82,12 +82,7 @@ func CreateProvider(c *gin.Context) {
// UpdateProvider 更新模型厂商
func UpdateProvider(c *gin.Context) {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid ID"})
return
}
oldID := c.Param("id")
var provider models.Provider
if err := c.ShouldBindJSON(&provider); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -95,25 +90,68 @@ func UpdateProvider(c *gin.Context) {
}
db := c.MustGet("db").(*sql.DB)
now := time.Now()
_, err = db.Exec(`
UPDATE provider
SET name = ?, icon = ?, updated_at = ?
WHERE id = ?`,
provider.Name, provider.Icon, now, id)
// 开始事务
tx, err := db.Begin()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to begin transaction"})
return
}
// 获取更新后的模型厂商信息
err = db.QueryRow(`
SELECT id, name, icon, created_at, updated_at, created_by
FROM provider WHERE id = ?`, id).Scan(
&provider.ID, &provider.Name, &provider.Icon,
&provider.CreatedAt, &provider.UpdatedAt, &provider.CreatedBy)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated provider"})
// 如果ID发生变化需要同时更新price表中的引用
if oldID != strconv.FormatUint(uint64(provider.ID), 10) {
// 更新price表中的channel_type
_, err = tx.Exec("UPDATE price SET channel_type = ? WHERE channel_type = ?", provider.ID, oldID)
if err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price references"})
return
}
// 更新price表中的temp_channel_type
_, err = tx.Exec("UPDATE price SET temp_channel_type = ? WHERE temp_channel_type = ?", provider.ID, oldID)
if err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update price temp references"})
return
}
// 删除旧记录
_, err = tx.Exec("DELETE FROM provider WHERE id = ?", oldID)
if err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete old provider"})
return
}
// 插入新记录
_, err = tx.Exec(`
INSERT INTO provider (id, name, icon, created_at, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
`, provider.ID, provider.Name, provider.Icon)
if err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create new provider"})
return
}
} else {
// 如果ID没有变化直接更新
_, err = tx.Exec(`
UPDATE provider
SET name = ?, icon = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, provider.Name, provider.Icon, oldID)
if err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update provider"})
return
}
}
// 提交事务
if err := tx.Commit(); err != nil {
tx.Rollback()
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
return
}

View File

@ -43,6 +43,24 @@
</div>
</div>
<div class="filter-section">
<div class="filter-label">模型类别:</div>
<div class="model-type-filters">
<el-button
:type="!selectedModelType ? 'primary' : ''"
@click="selectedModelType = ''"
>全部</el-button>
<el-button
v-for="(label, key) in modelTypeMap"
:key="key"
:type="selectedModelType === key ? 'primary' : ''"
@click="selectedModelType = key"
>
{{ label }}
</el-button>
</div>
</div>
<!-- 添加骨架屏 -->
<template v-if="loading">
<div v-for="i in 5" :key="i" class="skeleton-row">
@ -444,6 +462,7 @@ const form = ref({
})
const router = useRouter()
const selectedProvider = ref('')
const selectedModelType = ref('')
const isAdmin = computed(() => props.user?.role === 'admin')
@ -497,10 +516,13 @@ const loadPrices = async () => {
pageSize: pageSize.value
}
//
//
if (selectedProvider.value) {
params.channel_type = selectedProvider.value
}
if (selectedModelType.value) {
params.model_type = selectedModelType.value
}
try {
const [pricesRes, providersRes] = await Promise.all([
@ -513,7 +535,7 @@ const loadPrices = async () => {
providers.value = providersRes.data
//
const cacheKey = `${currentPage.value}-${pageSize.value}-${selectedProvider.value}`
const cacheKey = `${currentPage.value}-${pageSize.value}-${selectedProvider.value}-${selectedModelType.value}`
cachedPrices.value.set(cacheKey, {
prices: pricesRes.data.prices,
total: pricesRes.data.total
@ -931,6 +953,12 @@ watch(selectedProvider, () => {
loadPrices()
})
//
watch(selectedModelType, () => {
currentPage.value = 1 //
loadPrices()
})
onMounted(async () => {
await loadModelTypes()
await loadPrices()
@ -968,6 +996,13 @@ onMounted(async () => {
align-items: center;
}
.model-type-filters {
display: flex;
flex-wrap: wrap;
gap: 8px;
align-items: center;
}
:deep(.el-button) {
margin: 0;
}

View File

@ -38,16 +38,16 @@
</el-table>
</el-card>
<el-dialog v-model="dialogVisible" :title="dialogTitle">
<el-form :model="form" label-width="80px">
<el-form-item label="ID" v-if="!editingProvider">
<el-input-number v-model="form.id" :min="1" />
<el-dialog v-model="dialogVisible" :title="editingProvider ? '编辑模型厂商' : '添加模型厂商'" width="500px">
<el-form :model="form" label-width="100px">
<el-form-item label="ID">
<el-input v-model="form.id" placeholder="请输入厂商ID" />
</el-form-item>
<el-form-item label="名称">
<el-input v-model="form.name" />
<el-input v-model="form.name" placeholder="请输入厂商名称" />
</el-form-item>
<el-form-item label="图标链接">
<el-input v-model="form.icon" />
<el-form-item label="图标">
<el-input v-model="form.icon" placeholder="请输入图标URL" />
</el-form-item>
</el-form>
<template #footer>
@ -74,7 +74,7 @@ const providers = ref([])
const dialogVisible = ref(false)
const editingProvider = ref(null)
const form = ref({
id: 1,
id: '',
name: '',
icon: ''
})
@ -109,19 +109,9 @@ onMounted(() => {
loadProviders()
})
const dialogTitle = computed(() => {
if (editingProvider.value) {
return '编辑模型厂商'
}
return '添加模型厂商'
})
const handleEdit = (provider) => {
if (!isAdmin.value) {
ElMessage.warning('只有管理员可以编辑模型厂商信息')
return
}
editingProvider.value = provider
//
form.value = { ...provider }
dialogVisible.value = true
}
@ -162,8 +152,14 @@ const handleAdd = () => {
ElMessage.warning('请先登录')
return
}
editingProvider.value = null
//
form.value = {
id: '',
name: '',
icon: ''
}
dialogVisible.value = true
form.value = { id: 1, name: '', icon: '' }
}
const submitForm = async () => {
@ -196,7 +192,7 @@ const submitForm = async () => {
}
dialogVisible.value = false
editingProvider.value = null
form.value = { id: 1, name: '', icon: '' }
form.value = { id: '', name: '', icon: '' }
} catch (error) {
console.error('Failed to submit provider:', error)
if (error.response?.data?.error) {