Add model type support for pricing management

This commit is contained in:
wood chen 2025-02-21 11:51:53 +08:00
parent d4aebb8148
commit 6fa37f6d6a
12 changed files with 384 additions and 29 deletions

View File

@ -21,9 +21,9 @@ jobs:
# 设置 Go 环境
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: '1.21'
go-version: '1.23'
# 构建后端(使用 Alpine 环境)
- name: Build backend
@ -31,12 +31,14 @@ jobs:
cd backend
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o main-amd64 .
GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o main-arm64 .
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o migrate-amd64 scripts/migrate.go
GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o migrate-arm64 scripts/migrate.go
# 设置 Node.js 环境
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '18'
node-version: '22'
cache: 'npm'
cache-dependency-path: frontend/package-lock.json

1
.gitignore vendored
View File

@ -30,3 +30,4 @@ out/
# 日志文件
*.log
logs/
backend/data/aimodels.db

View File

@ -24,6 +24,16 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \
rm main-* && \
chmod +x main
# 复制迁移工具
COPY backend/migrate-* ./
RUN if [ "$(uname -m)" = "aarch64" ]; then \
cp migrate-arm64 migrate; \
else \
cp migrate-amd64 migrate; \
fi && \
rm migrate-* && \
chmod +x migrate
COPY frontend/dist /app/frontend
COPY backend/config/nginx.conf /etc/nginx/nginx.conf
COPY scripts/start.sh ./

View File

@ -0,0 +1,61 @@
package handlers
import (
"database/sql"
"github.com/gin-gonic/gin"
)
// ModelType 模型类型结构
type ModelType struct {
Key string `json:"key"`
Label string `json:"label"`
}
// GetModelTypes 获取所有模型类型
func GetModelTypes(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
rows, err := db.Query("SELECT key, label FROM model_type")
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
defer rows.Close()
var types []ModelType
for rows.Next() {
var t ModelType
if err := rows.Scan(&t.Key, &t.Label); err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
types = append(types, t)
}
c.JSON(200, types)
}
// CreateModelType 添加新的模型类型
func CreateModelType(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
var newType ModelType
if err := c.ShouldBindJSON(&newType); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
_, err := db.Exec(`
INSERT INTO model_type (key, label)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET label = excluded.label
`, newType.Key, newType.Label)
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
c.JSON(201, newType)
}

View File

@ -50,9 +50,9 @@ func GetPrices(c *gin.Context) {
// 使用分页查询
query := `
SELECT id, model, billing_type, channel_type, currency, input_price, output_price,
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
price_source, status, created_at, updated_at, created_by,
temp_model, temp_billing_type, temp_channel_type, temp_currency,
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
temp_input_price, temp_output_price, temp_price_source, updated_by
FROM price`
if whereClause != "" {
@ -72,10 +72,10 @@ func GetPrices(c *gin.Context) {
for rows.Next() {
var price models.Price
if err := rows.Scan(
&price.ID, &price.Model, &price.BillingType, &price.ChannelType, &price.Currency,
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
&price.TempModel, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"})
return
@ -107,10 +107,10 @@ func CreatePrice(c *gin.Context) {
now := time.Now()
result, err := db.Exec(`
INSERT INTO price (model, billing_type, channel_type, currency, input_price, output_price,
INSERT INTO price (model, model_type, billing_type, channel_type, currency, input_price, output_price,
price_source, status, created_by, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`,
price.Model, price.BillingType, price.ChannelType, price.Currency,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`,
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy,
now, now)
if err != nil {
@ -146,6 +146,7 @@ func UpdatePriceStatus(c *gin.Context) {
_, err := db.Exec(`
UPDATE price
SET model = COALESCE(temp_model, model),
model_type = COALESCE(temp_model_type, model_type),
billing_type = COALESCE(temp_billing_type, billing_type),
channel_type = COALESCE(temp_channel_type, channel_type),
currency = COALESCE(temp_currency, currency),
@ -155,6 +156,7 @@ func UpdatePriceStatus(c *gin.Context) {
status = ?,
updated_at = ?,
temp_model = NULL,
temp_model_type = NULL,
temp_billing_type = NULL,
temp_channel_type = NULL,
temp_currency = NULL,
@ -174,6 +176,7 @@ func UpdatePriceStatus(c *gin.Context) {
SET status = ?,
updated_at = ?,
temp_model = NULL,
temp_model_type = NULL,
temp_billing_type = NULL,
temp_channel_type = NULL,
temp_currency = NULL,
@ -225,11 +228,11 @@ func UpdatePrice(c *gin.Context) {
// 将新的价格信息存储到临时字段
_, err = db.Exec(`
UPDATE price
SET temp_model = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?,
SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?,
temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
updated_by = ?, updated_at = ?, status = 'pending'
WHERE id = ?`,
price.Model, price.BillingType, price.ChannelType, price.Currency,
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
price.InputPrice, price.OutputPrice, price.PriceSource,
currentUser.Username, now, id)
if err != nil {
@ -239,15 +242,15 @@ func UpdatePrice(c *gin.Context) {
// 获取更新后的价格信息
err = db.QueryRow(`
SELECT id, model, billing_type, channel_type, currency, input_price, output_price,
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
price_source, status, created_at, updated_at, created_by,
temp_model, temp_billing_type, temp_channel_type, temp_currency,
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
temp_input_price, temp_output_price, temp_price_source, updated_by
FROM price WHERE id = ?`, id).Scan(
&price.ID, &price.Model, &price.BillingType, &price.ChannelType, &price.Currency,
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
&price.TempModel, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"})
@ -274,6 +277,7 @@ func DeletePrice(c *gin.Context) {
// PriceRate 价格倍率结构
type PriceRate struct {
Model string `json:"model"`
ModelType string `json:"model_type"`
Type string `json:"type"`
ChannelType uint `json:"channel_type"`
Input float64 `json:"input"`
@ -284,7 +288,7 @@ type PriceRate struct {
func GetPriceRates(c *gin.Context) {
db := c.MustGet("db").(*sql.DB)
rows, err := db.Query(`
SELECT model, billing_type, channel_type,
SELECT model, model_type, billing_type, channel_type,
CASE
WHEN currency = 'USD' THEN input_price / 2
ELSE input_price / 14
@ -307,6 +311,7 @@ func GetPriceRates(c *gin.Context) {
var rate PriceRate
if err := rows.Scan(
&rate.Model,
&rate.ModelType,
&rate.Type,
&rate.ChannelType,
&rate.Input,

View File

@ -88,6 +88,13 @@ func main() {
auth.GET("/user", handlers.GetUser)
auth.GET("/callback", handlers.AuthCallback)
}
// 模型类型相关路由
modelTypes := api.Group("/model-types")
{
modelTypes.GET("", handlers.GetModelTypes)
modelTypes.POST("", middleware.AuthRequired(), handlers.CreateModelType)
}
}
// 启动服务器

View File

@ -7,6 +7,7 @@ import (
type Price struct {
ID uint `json:"id"`
Model string `json:"model"`
ModelType string `json:"model_type"` // text2text, text2image, etc.
BillingType string `json:"billing_type"` // tokens or times
ChannelType string `json:"channel_type"`
Currency string `json:"currency"` // USD or CNY
@ -19,6 +20,7 @@ type Price struct {
CreatedBy string `json:"created_by"`
// 临时字段,用于存储待审核的更新
TempModel *string `json:"temp_model,omitempty"`
TempModelType *string `json:"temp_model_type,omitempty"`
TempBillingType *string `json:"temp_billing_type,omitempty"`
TempChannelType *string `json:"temp_channel_type,omitempty"`
TempCurrency *string `json:"temp_currency,omitempty"`
@ -34,6 +36,7 @@ func CreatePriceTableSQL() string {
CREATE TABLE IF NOT EXISTS price (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model TEXT NOT NULL,
model_type TEXT NOT NULL,
billing_type TEXT NOT NULL,
channel_type TEXT NOT NULL,
currency TEXT NOT NULL,
@ -45,6 +48,7 @@ func CreatePriceTableSQL() string {
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_by TEXT NOT NULL,
temp_model TEXT,
temp_model_type TEXT,
temp_billing_type TEXT,
temp_channel_type TEXT,
temp_currency TEXT,

174
backend/scripts/migrate.go Normal file
View File

@ -0,0 +1,174 @@
package main
import (
"database/sql"
"log"
"os"
"path/filepath"
"strings"
_ "modernc.org/sqlite"
)
// ModelType 模型类型结构
type ModelType struct {
Key string `json:"key"`
Label string `json:"label"`
}
func main() {
// 确保数据目录存在
dbDir := "./data"
if err := os.MkdirAll(dbDir, 0755); err != nil {
log.Fatalf("创建数据目录失败: %v", err)
}
// 连接数据库
dbPath := filepath.Join(dbDir, "aimodels.db")
db, err := sql.Open("sqlite", dbPath)
if err != nil {
log.Fatalf("连接数据库失败: %v", err)
}
defer db.Close()
// 创建model_type表
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS model_type (
key TEXT PRIMARY KEY,
label TEXT NOT NULL
)
`)
if err != nil {
log.Fatalf("创建model_type表失败: %v", err)
}
// 初始化默认的模型类型
defaultTypes := []ModelType{
{Key: "text2text", Label: "文生文"},
{Key: "text2image", Label: "文生图"},
{Key: "text2speech", Label: "文生音"},
{Key: "speech2text", Label: "音生文"},
{Key: "image2text", Label: "图生文"},
{Key: "embedding", Label: "向量"},
{Key: "other", Label: "其他"},
}
// 插入默认类型
for _, t := range defaultTypes {
_, err = db.Exec(`
INSERT OR IGNORE INTO model_type (key, label)
VALUES (?, ?)
`, t.Key, t.Label)
if err != nil {
log.Printf("插入默认类型失败 %s: %v", t.Key, err)
}
}
// 检查model_type列是否存在
var hasModelType bool
err = db.QueryRow(`
SELECT COUNT(*) > 0
FROM pragma_table_info('price')
WHERE name = 'model_type'
`).Scan(&hasModelType)
if err != nil {
log.Fatalf("检查model_type列失败: %v", err)
}
// 如果model_type列不存在,则添加它
if !hasModelType {
log.Println("开始添加model_type列...")
// 开始事务
tx, err := db.Begin()
if err != nil {
log.Fatalf("开始事务失败: %v", err)
}
// 添加model_type列
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`)
if err != nil {
tx.Rollback()
log.Fatalf("添加model_type列失败: %v", err)
}
// 添加temp_model_type列
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`)
if err != nil {
tx.Rollback()
log.Fatalf("添加temp_model_type列失败: %v", err)
}
// 根据模型名称推断类型并更新
rows, err := tx.Query(`SELECT id, model FROM price`)
if err != nil {
tx.Rollback()
log.Fatalf("查询价格数据失败: %v", err)
}
defer rows.Close()
for rows.Next() {
var id int
var model string
if err := rows.Scan(&id, &model); err != nil {
tx.Rollback()
log.Fatalf("读取行数据失败: %v", err)
}
// 根据模型名称推断类型
modelType := inferModelType(model)
// 更新model_type
_, err = tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, modelType, id)
if err != nil {
tx.Rollback()
log.Fatalf("更新model_type失败: %v", err)
}
}
// 提交事务
if err := tx.Commit(); err != nil {
log.Fatalf("提交事务失败: %v", err)
}
log.Println("成功添加并更新model_type列")
} else {
log.Println("model_type列已存在,无需迁移")
}
}
// inferModelType 根据模型名称推断模型类型
func inferModelType(model string) string {
model = strings.ToLower(model)
switch {
case strings.Contains(model, "gpt") ||
strings.Contains(model, "llama") ||
strings.Contains(model, "claude") ||
strings.Contains(model, "palm") ||
strings.Contains(model, "gemini") ||
strings.Contains(model, "qwen") ||
strings.Contains(model, "chatglm"):
return "text2text"
case strings.Contains(model, "dall-e") ||
strings.Contains(model, "stable") ||
strings.Contains(model, "midjourney") ||
strings.Contains(model, "sd") ||
strings.Contains(model, "diffusion"):
return "text2image"
case strings.Contains(model, "whisper") ||
strings.Contains(model, "speech") ||
strings.Contains(model, "tts"):
return "text2speech"
case strings.Contains(model, "embedding") ||
strings.Contains(model, "ada") ||
strings.Contains(model, "text-embedding"):
return "embedding"
default:
return "other"
}
}

View File

@ -24,7 +24,7 @@
<h2>API文档</h2>
<el-collapse>
<el-collapse-item title="获取价格倍率">
<el-collapse-item title="One-Hub 价格倍率">
<div class="api-doc">
<div class="api-url">
<span class="method">GET</span>

View File

@ -68,6 +68,16 @@
</div>
</template>
</el-table-column>
<el-table-column label="模型类型" width="120">
<template #default="{ row }">
<div class="value-container">
<span>{{ getModelType(row.model_type) }}</span>
<el-tag v-if="row.temp_model_type" type="warning" size="small" effect="light">
待审核: {{ getModelType(row.temp_model_type) }}
</el-tag>
</div>
</template>
</el-table-column>
<el-table-column label="计费类型" width="120">
<template #default="{ row }">
<div class="value-container">
@ -125,16 +135,6 @@
</div>
</template>
</el-table-column>
<el-table-column label="输入倍率" width="120">
<template #default="{ row }">
{{ row.input_price === 0 ? '免费' : calculateRate(row.input_price, row.currency) }}
</template>
</el-table-column>
<el-table-column label="输出倍率" width="120">
<template #default="{ row }">
{{ row.output_price === 0 ? '免费' : calculateRate(row.output_price, row.currency) }}
</template>
</el-table-column>
<el-table-column width="80">
<template #default="{ row }">
<el-popover
@ -246,6 +246,24 @@ dall-e-3 按Token收费 OpenAI 美元 40.000000 40.000000"
<el-input v-model="row.model" placeholder="请输入模型名称" />
</template>
</el-table-column>
<el-table-column label="模型类型" width="120">
<template #default="{ row }">
<el-select
v-model="row.model_type"
placeholder="请选择或输入"
allow-create
filterable
@create="handleModelTypeCreate"
>
<el-option
v-for="(label, value) in modelTypeMap"
:key="value"
:label="label"
:value="value"
/>
</el-select>
</template>
</el-table-column>
<el-table-column label="计费类型" width="120">
<template #default="{ row }">
<el-select v-model="row.billing_type" placeholder="请选择">
@ -319,6 +337,24 @@ dall-e-3 按Token收费 OpenAI 美元 40.000000 40.000000"
<el-input v-model="form.model" />
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="模型类型">
<el-select
v-model="form.model_type"
placeholder="请选择或输入"
allow-create
filterable
@create="handleModelTypeCreate"
>
<el-option
v-for="(label, value) in modelTypeMap"
:key="value"
:label="label"
:value="value"
/>
</el-select>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item label="计费类型">
<el-select v-model="form.billing_type" placeholder="请选择">
@ -397,6 +433,7 @@ const prices = ref([])
const dialogVisible = ref(false)
const form = ref({
model: '',
model_type: '',
billing_type: 'tokens',
channel_type: '',
currency: 'USD',
@ -530,6 +567,7 @@ const handleAdd = () => {
editingPrice.value = null
form.value = {
model: '',
model_type: '',
billing_type: 'tokens',
channel_type: '',
currency: 'USD',
@ -597,6 +635,7 @@ const handleSubmitResponse = async (response) => {
editingPrice.value = null
form.value = {
model: '',
model_type: '',
billing_type: 'tokens',
channel_type: '',
currency: 'USD',
@ -633,9 +672,57 @@ const batchForms = ref([])
const selectedRows = ref([])
const batchSubmitting = ref(false)
//
const modelTypeMap = ref({})
//
const loadModelTypes = async () => {
try {
const response = await axios.get('/api/model-types')
const types = response.data
const map = {}
types.forEach(type => {
map[type.key] = type.label
})
modelTypeMap.value = map
} catch (error) {
console.error('Failed to load model types:', error)
ElMessage.error('加载模型类型失败')
}
}
//
const handleModelTypeCreate = async (value) => {
// key
const existingKey = Object.entries(modelTypeMap.value).find(([_, label]) => label === value)?.[0]
if (existingKey) {
return existingKey
}
// key使
let key = value
let label = value
if (!/^[a-zA-Z0-9_]+$/.test(value)) {
// key
key = `type_${Date.now()}`
label = value
}
try {
await axios.post('/api/model-types', { key, label })
modelTypeMap.value[key] = label
return key
} catch (error) {
console.error('Failed to create model type:', error)
ElMessage.error('创建模型类型失败')
return 'other'
}
}
//
const createNewRow = () => ({
model: '',
model_type: '',
billing_type: 'tokens',
channel_type: '',
currency: 'USD',
@ -839,6 +926,7 @@ watch(selectedProvider, () => {
})
onMounted(async () => {
await loadModelTypes()
await loadPrices()
})
</script>

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,9 @@
#!/bin/bash
# 执行数据库迁移
echo "执行数据库迁移..."
./migrate
# 启动后端服务
./main &