mirror of
https://github.com/woodchen-ink/aimodels-prices.git
synced 2025-07-18 13:41:59 +08:00
Add model type support for pricing management
This commit is contained in:
parent
d4aebb8148
commit
6fa37f6d6a
8
.github/workflows/docker-build.yml
vendored
8
.github/workflows/docker-build.yml
vendored
@ -21,9 +21,9 @@ jobs:
|
|||||||
|
|
||||||
# 设置 Go 环境
|
# 设置 Go 环境
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.21'
|
go-version: '1.23'
|
||||||
|
|
||||||
# 构建后端(使用 Alpine 环境)
|
# 构建后端(使用 Alpine 环境)
|
||||||
- name: Build backend
|
- name: Build backend
|
||||||
@ -31,12 +31,14 @@ jobs:
|
|||||||
cd backend
|
cd backend
|
||||||
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o main-amd64 .
|
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o main-amd64 .
|
||||||
GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o main-arm64 .
|
GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o main-arm64 .
|
||||||
|
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o migrate-amd64 scripts/migrate.go
|
||||||
|
GOOS=linux GOARCH=arm64 CGO_ENABLED=0 go build -o migrate-arm64 scripts/migrate.go
|
||||||
|
|
||||||
# 设置 Node.js 环境
|
# 设置 Node.js 环境
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: '18'
|
node-version: '22'
|
||||||
cache: 'npm'
|
cache: 'npm'
|
||||||
cache-dependency-path: frontend/package-lock.json
|
cache-dependency-path: frontend/package-lock.json
|
||||||
|
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -30,3 +30,4 @@ out/
|
|||||||
# 日志文件
|
# 日志文件
|
||||||
*.log
|
*.log
|
||||||
logs/
|
logs/
|
||||||
|
backend/data/aimodels.db
|
||||||
|
10
Dockerfile
10
Dockerfile
@ -24,6 +24,16 @@ RUN if [ "$(uname -m)" = "aarch64" ]; then \
|
|||||||
rm main-* && \
|
rm main-* && \
|
||||||
chmod +x main
|
chmod +x main
|
||||||
|
|
||||||
|
# 复制迁移工具
|
||||||
|
COPY backend/migrate-* ./
|
||||||
|
RUN if [ "$(uname -m)" = "aarch64" ]; then \
|
||||||
|
cp migrate-arm64 migrate; \
|
||||||
|
else \
|
||||||
|
cp migrate-amd64 migrate; \
|
||||||
|
fi && \
|
||||||
|
rm migrate-* && \
|
||||||
|
chmod +x migrate
|
||||||
|
|
||||||
COPY frontend/dist /app/frontend
|
COPY frontend/dist /app/frontend
|
||||||
COPY backend/config/nginx.conf /etc/nginx/nginx.conf
|
COPY backend/config/nginx.conf /etc/nginx/nginx.conf
|
||||||
COPY scripts/start.sh ./
|
COPY scripts/start.sh ./
|
||||||
|
61
backend/handlers/model_type.go
Normal file
61
backend/handlers/model_type.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelType 模型类型结构
|
||||||
|
type ModelType struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelTypes 获取所有模型类型
|
||||||
|
func GetModelTypes(c *gin.Context) {
|
||||||
|
db := c.MustGet("db").(*sql.DB)
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT key, label FROM model_type")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var types []ModelType
|
||||||
|
for rows.Next() {
|
||||||
|
var t ModelType
|
||||||
|
if err := rows.Scan(&t.Key, &t.Label); err != nil {
|
||||||
|
c.JSON(500, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
types = append(types, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, types)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateModelType 添加新的模型类型
|
||||||
|
func CreateModelType(c *gin.Context) {
|
||||||
|
db := c.MustGet("db").(*sql.DB)
|
||||||
|
|
||||||
|
var newType ModelType
|
||||||
|
if err := c.ShouldBindJSON(&newType); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO model_type (key, label)
|
||||||
|
VALUES (?, ?)
|
||||||
|
ON CONFLICT(key) DO UPDATE SET label = excluded.label
|
||||||
|
`, newType.Key, newType.Label)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(201, newType)
|
||||||
|
}
|
@ -50,9 +50,9 @@ func GetPrices(c *gin.Context) {
|
|||||||
|
|
||||||
// 使用分页查询
|
// 使用分页查询
|
||||||
query := `
|
query := `
|
||||||
SELECT id, model, billing_type, channel_type, currency, input_price, output_price,
|
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
|
||||||
price_source, status, created_at, updated_at, created_by,
|
price_source, status, created_at, updated_at, created_by,
|
||||||
temp_model, temp_billing_type, temp_channel_type, temp_currency,
|
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
|
||||||
temp_input_price, temp_output_price, temp_price_source, updated_by
|
temp_input_price, temp_output_price, temp_price_source, updated_by
|
||||||
FROM price`
|
FROM price`
|
||||||
if whereClause != "" {
|
if whereClause != "" {
|
||||||
@ -72,10 +72,10 @@ func GetPrices(c *gin.Context) {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var price models.Price
|
var price models.Price
|
||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&price.ID, &price.Model, &price.BillingType, &price.ChannelType, &price.Currency,
|
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
|
||||||
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
|
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
|
||||||
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
|
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
|
||||||
&price.TempModel, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
|
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
|
||||||
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil {
|
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to scan price"})
|
||||||
return
|
return
|
||||||
@ -107,10 +107,10 @@ func CreatePrice(c *gin.Context) {
|
|||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
result, err := db.Exec(`
|
result, err := db.Exec(`
|
||||||
INSERT INTO price (model, billing_type, channel_type, currency, input_price, output_price,
|
INSERT INTO price (model, model_type, billing_type, channel_type, currency, input_price, output_price,
|
||||||
price_source, status, created_by, created_at, updated_at)
|
price_source, status, created_by, created_at, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)`,
|
||||||
price.Model, price.BillingType, price.ChannelType, price.Currency,
|
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
|
||||||
price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy,
|
price.InputPrice, price.OutputPrice, price.PriceSource, price.CreatedBy,
|
||||||
now, now)
|
now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -146,6 +146,7 @@ func UpdatePriceStatus(c *gin.Context) {
|
|||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
UPDATE price
|
UPDATE price
|
||||||
SET model = COALESCE(temp_model, model),
|
SET model = COALESCE(temp_model, model),
|
||||||
|
model_type = COALESCE(temp_model_type, model_type),
|
||||||
billing_type = COALESCE(temp_billing_type, billing_type),
|
billing_type = COALESCE(temp_billing_type, billing_type),
|
||||||
channel_type = COALESCE(temp_channel_type, channel_type),
|
channel_type = COALESCE(temp_channel_type, channel_type),
|
||||||
currency = COALESCE(temp_currency, currency),
|
currency = COALESCE(temp_currency, currency),
|
||||||
@ -155,6 +156,7 @@ func UpdatePriceStatus(c *gin.Context) {
|
|||||||
status = ?,
|
status = ?,
|
||||||
updated_at = ?,
|
updated_at = ?,
|
||||||
temp_model = NULL,
|
temp_model = NULL,
|
||||||
|
temp_model_type = NULL,
|
||||||
temp_billing_type = NULL,
|
temp_billing_type = NULL,
|
||||||
temp_channel_type = NULL,
|
temp_channel_type = NULL,
|
||||||
temp_currency = NULL,
|
temp_currency = NULL,
|
||||||
@ -174,6 +176,7 @@ func UpdatePriceStatus(c *gin.Context) {
|
|||||||
SET status = ?,
|
SET status = ?,
|
||||||
updated_at = ?,
|
updated_at = ?,
|
||||||
temp_model = NULL,
|
temp_model = NULL,
|
||||||
|
temp_model_type = NULL,
|
||||||
temp_billing_type = NULL,
|
temp_billing_type = NULL,
|
||||||
temp_channel_type = NULL,
|
temp_channel_type = NULL,
|
||||||
temp_currency = NULL,
|
temp_currency = NULL,
|
||||||
@ -225,11 +228,11 @@ func UpdatePrice(c *gin.Context) {
|
|||||||
// 将新的价格信息存储到临时字段
|
// 将新的价格信息存储到临时字段
|
||||||
_, err = db.Exec(`
|
_, err = db.Exec(`
|
||||||
UPDATE price
|
UPDATE price
|
||||||
SET temp_model = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?,
|
SET temp_model = ?, temp_model_type = ?, temp_billing_type = ?, temp_channel_type = ?, temp_currency = ?,
|
||||||
temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
|
temp_input_price = ?, temp_output_price = ?, temp_price_source = ?,
|
||||||
updated_by = ?, updated_at = ?, status = 'pending'
|
updated_by = ?, updated_at = ?, status = 'pending'
|
||||||
WHERE id = ?`,
|
WHERE id = ?`,
|
||||||
price.Model, price.BillingType, price.ChannelType, price.Currency,
|
price.Model, price.ModelType, price.BillingType, price.ChannelType, price.Currency,
|
||||||
price.InputPrice, price.OutputPrice, price.PriceSource,
|
price.InputPrice, price.OutputPrice, price.PriceSource,
|
||||||
currentUser.Username, now, id)
|
currentUser.Username, now, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -239,15 +242,15 @@ func UpdatePrice(c *gin.Context) {
|
|||||||
|
|
||||||
// 获取更新后的价格信息
|
// 获取更新后的价格信息
|
||||||
err = db.QueryRow(`
|
err = db.QueryRow(`
|
||||||
SELECT id, model, billing_type, channel_type, currency, input_price, output_price,
|
SELECT id, model, model_type, billing_type, channel_type, currency, input_price, output_price,
|
||||||
price_source, status, created_at, updated_at, created_by,
|
price_source, status, created_at, updated_at, created_by,
|
||||||
temp_model, temp_billing_type, temp_channel_type, temp_currency,
|
temp_model, temp_model_type, temp_billing_type, temp_channel_type, temp_currency,
|
||||||
temp_input_price, temp_output_price, temp_price_source, updated_by
|
temp_input_price, temp_output_price, temp_price_source, updated_by
|
||||||
FROM price WHERE id = ?`, id).Scan(
|
FROM price WHERE id = ?`, id).Scan(
|
||||||
&price.ID, &price.Model, &price.BillingType, &price.ChannelType, &price.Currency,
|
&price.ID, &price.Model, &price.ModelType, &price.BillingType, &price.ChannelType, &price.Currency,
|
||||||
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
|
&price.InputPrice, &price.OutputPrice, &price.PriceSource, &price.Status,
|
||||||
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
|
&price.CreatedAt, &price.UpdatedAt, &price.CreatedBy,
|
||||||
&price.TempModel, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
|
&price.TempModel, &price.TempModelType, &price.TempBillingType, &price.TempChannelType, &price.TempCurrency,
|
||||||
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy)
|
&price.TempInputPrice, &price.TempOutputPrice, &price.TempPriceSource, &price.UpdatedBy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get updated price"})
|
||||||
@ -274,6 +277,7 @@ func DeletePrice(c *gin.Context) {
|
|||||||
// PriceRate 价格倍率结构
|
// PriceRate 价格倍率结构
|
||||||
type PriceRate struct {
|
type PriceRate struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
ModelType string `json:"model_type"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
ChannelType uint `json:"channel_type"`
|
ChannelType uint `json:"channel_type"`
|
||||||
Input float64 `json:"input"`
|
Input float64 `json:"input"`
|
||||||
@ -284,7 +288,7 @@ type PriceRate struct {
|
|||||||
func GetPriceRates(c *gin.Context) {
|
func GetPriceRates(c *gin.Context) {
|
||||||
db := c.MustGet("db").(*sql.DB)
|
db := c.MustGet("db").(*sql.DB)
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
SELECT model, billing_type, channel_type,
|
SELECT model, model_type, billing_type, channel_type,
|
||||||
CASE
|
CASE
|
||||||
WHEN currency = 'USD' THEN input_price / 2
|
WHEN currency = 'USD' THEN input_price / 2
|
||||||
ELSE input_price / 14
|
ELSE input_price / 14
|
||||||
@ -307,6 +311,7 @@ func GetPriceRates(c *gin.Context) {
|
|||||||
var rate PriceRate
|
var rate PriceRate
|
||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&rate.Model,
|
&rate.Model,
|
||||||
|
&rate.ModelType,
|
||||||
&rate.Type,
|
&rate.Type,
|
||||||
&rate.ChannelType,
|
&rate.ChannelType,
|
||||||
&rate.Input,
|
&rate.Input,
|
||||||
|
@ -88,6 +88,13 @@ func main() {
|
|||||||
auth.GET("/user", handlers.GetUser)
|
auth.GET("/user", handlers.GetUser)
|
||||||
auth.GET("/callback", handlers.AuthCallback)
|
auth.GET("/callback", handlers.AuthCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 模型类型相关路由
|
||||||
|
modelTypes := api.Group("/model-types")
|
||||||
|
{
|
||||||
|
modelTypes.GET("", handlers.GetModelTypes)
|
||||||
|
modelTypes.POST("", middleware.AuthRequired(), handlers.CreateModelType)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动服务器
|
// 启动服务器
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
type Price struct {
|
type Price struct {
|
||||||
ID uint `json:"id"`
|
ID uint `json:"id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
ModelType string `json:"model_type"` // text2text, text2image, etc.
|
||||||
BillingType string `json:"billing_type"` // tokens or times
|
BillingType string `json:"billing_type"` // tokens or times
|
||||||
ChannelType string `json:"channel_type"`
|
ChannelType string `json:"channel_type"`
|
||||||
Currency string `json:"currency"` // USD or CNY
|
Currency string `json:"currency"` // USD or CNY
|
||||||
@ -19,6 +20,7 @@ type Price struct {
|
|||||||
CreatedBy string `json:"created_by"`
|
CreatedBy string `json:"created_by"`
|
||||||
// 临时字段,用于存储待审核的更新
|
// 临时字段,用于存储待审核的更新
|
||||||
TempModel *string `json:"temp_model,omitempty"`
|
TempModel *string `json:"temp_model,omitempty"`
|
||||||
|
TempModelType *string `json:"temp_model_type,omitempty"`
|
||||||
TempBillingType *string `json:"temp_billing_type,omitempty"`
|
TempBillingType *string `json:"temp_billing_type,omitempty"`
|
||||||
TempChannelType *string `json:"temp_channel_type,omitempty"`
|
TempChannelType *string `json:"temp_channel_type,omitempty"`
|
||||||
TempCurrency *string `json:"temp_currency,omitempty"`
|
TempCurrency *string `json:"temp_currency,omitempty"`
|
||||||
@ -34,6 +36,7 @@ func CreatePriceTableSQL() string {
|
|||||||
CREATE TABLE IF NOT EXISTS price (
|
CREATE TABLE IF NOT EXISTS price (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
model TEXT NOT NULL,
|
model TEXT NOT NULL,
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
billing_type TEXT NOT NULL,
|
billing_type TEXT NOT NULL,
|
||||||
channel_type TEXT NOT NULL,
|
channel_type TEXT NOT NULL,
|
||||||
currency TEXT NOT NULL,
|
currency TEXT NOT NULL,
|
||||||
@ -45,6 +48,7 @@ func CreatePriceTableSQL() string {
|
|||||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
created_by TEXT NOT NULL,
|
created_by TEXT NOT NULL,
|
||||||
temp_model TEXT,
|
temp_model TEXT,
|
||||||
|
temp_model_type TEXT,
|
||||||
temp_billing_type TEXT,
|
temp_billing_type TEXT,
|
||||||
temp_channel_type TEXT,
|
temp_channel_type TEXT,
|
||||||
temp_currency TEXT,
|
temp_currency TEXT,
|
||||||
|
174
backend/scripts/migrate.go
Normal file
174
backend/scripts/migrate.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelType 模型类型结构
|
||||||
|
type ModelType struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 确保数据目录存在
|
||||||
|
dbDir := "./data"
|
||||||
|
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||||
|
log.Fatalf("创建数据目录失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接数据库
|
||||||
|
dbPath := filepath.Join(dbDir, "aimodels.db")
|
||||||
|
db, err := sql.Open("sqlite", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("连接数据库失败: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// 创建model_type表
|
||||||
|
_, err = db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS model_type (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
label TEXT NOT NULL
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("创建model_type表失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化默认的模型类型
|
||||||
|
defaultTypes := []ModelType{
|
||||||
|
{Key: "text2text", Label: "文生文"},
|
||||||
|
{Key: "text2image", Label: "文生图"},
|
||||||
|
{Key: "text2speech", Label: "文生音"},
|
||||||
|
{Key: "speech2text", Label: "音生文"},
|
||||||
|
{Key: "image2text", Label: "图生文"},
|
||||||
|
{Key: "embedding", Label: "向量"},
|
||||||
|
{Key: "other", Label: "其他"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 插入默认类型
|
||||||
|
for _, t := range defaultTypes {
|
||||||
|
_, err = db.Exec(`
|
||||||
|
INSERT OR IGNORE INTO model_type (key, label)
|
||||||
|
VALUES (?, ?)
|
||||||
|
`, t.Key, t.Label)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("插入默认类型失败 %s: %v", t.Key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查model_type列是否存在
|
||||||
|
var hasModelType bool
|
||||||
|
err = db.QueryRow(`
|
||||||
|
SELECT COUNT(*) > 0
|
||||||
|
FROM pragma_table_info('price')
|
||||||
|
WHERE name = 'model_type'
|
||||||
|
`).Scan(&hasModelType)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("检查model_type列失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果model_type列不存在,则添加它
|
||||||
|
if !hasModelType {
|
||||||
|
log.Println("开始添加model_type列...")
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("开始事务失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加model_type列
|
||||||
|
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN model_type TEXT`)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
log.Fatalf("添加model_type列失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加temp_model_type列
|
||||||
|
_, err = tx.Exec(`ALTER TABLE price ADD COLUMN temp_model_type TEXT`)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
log.Fatalf("添加temp_model_type列失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据模型名称推断类型并更新
|
||||||
|
rows, err := tx.Query(`SELECT id, model FROM price`)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
log.Fatalf("查询价格数据失败: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var id int
|
||||||
|
var model string
|
||||||
|
if err := rows.Scan(&id, &model); err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
log.Fatalf("读取行数据失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据模型名称推断类型
|
||||||
|
modelType := inferModelType(model)
|
||||||
|
|
||||||
|
// 更新model_type
|
||||||
|
_, err = tx.Exec(`UPDATE price SET model_type = ? WHERE id = ?`, modelType, id)
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
log.Fatalf("更新model_type失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
log.Fatalf("提交事务失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("成功添加并更新model_type列")
|
||||||
|
} else {
|
||||||
|
log.Println("model_type列已存在,无需迁移")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferModelType 根据模型名称推断模型类型
|
||||||
|
func inferModelType(model string) string {
|
||||||
|
model = strings.ToLower(model)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(model, "gpt") ||
|
||||||
|
strings.Contains(model, "llama") ||
|
||||||
|
strings.Contains(model, "claude") ||
|
||||||
|
strings.Contains(model, "palm") ||
|
||||||
|
strings.Contains(model, "gemini") ||
|
||||||
|
strings.Contains(model, "qwen") ||
|
||||||
|
strings.Contains(model, "chatglm"):
|
||||||
|
return "text2text"
|
||||||
|
|
||||||
|
case strings.Contains(model, "dall-e") ||
|
||||||
|
strings.Contains(model, "stable") ||
|
||||||
|
strings.Contains(model, "midjourney") ||
|
||||||
|
strings.Contains(model, "sd") ||
|
||||||
|
strings.Contains(model, "diffusion"):
|
||||||
|
return "text2image"
|
||||||
|
|
||||||
|
case strings.Contains(model, "whisper") ||
|
||||||
|
strings.Contains(model, "speech") ||
|
||||||
|
strings.Contains(model, "tts"):
|
||||||
|
return "text2speech"
|
||||||
|
|
||||||
|
case strings.Contains(model, "embedding") ||
|
||||||
|
strings.Contains(model, "ada") ||
|
||||||
|
strings.Contains(model, "text-embedding"):
|
||||||
|
return "embedding"
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "other"
|
||||||
|
}
|
||||||
|
}
|
@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
<h2>API文档</h2>
|
<h2>API文档</h2>
|
||||||
<el-collapse>
|
<el-collapse>
|
||||||
<el-collapse-item title="获取价格倍率">
|
<el-collapse-item title="One-Hub 价格倍率">
|
||||||
<div class="api-doc">
|
<div class="api-doc">
|
||||||
<div class="api-url">
|
<div class="api-url">
|
||||||
<span class="method">GET</span>
|
<span class="method">GET</span>
|
||||||
|
@ -68,6 +68,16 @@
|
|||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
|
<el-table-column label="模型类型" width="120">
|
||||||
|
<template #default="{ row }">
|
||||||
|
<div class="value-container">
|
||||||
|
<span>{{ getModelType(row.model_type) }}</span>
|
||||||
|
<el-tag v-if="row.temp_model_type" type="warning" size="small" effect="light">
|
||||||
|
待审核: {{ getModelType(row.temp_model_type) }}
|
||||||
|
</el-tag>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</el-table-column>
|
||||||
<el-table-column label="计费类型" width="120">
|
<el-table-column label="计费类型" width="120">
|
||||||
<template #default="{ row }">
|
<template #default="{ row }">
|
||||||
<div class="value-container">
|
<div class="value-container">
|
||||||
@ -125,16 +135,6 @@
|
|||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
<el-table-column label="输入倍率" width="120">
|
|
||||||
<template #default="{ row }">
|
|
||||||
{{ row.input_price === 0 ? '免费' : calculateRate(row.input_price, row.currency) }}
|
|
||||||
</template>
|
|
||||||
</el-table-column>
|
|
||||||
<el-table-column label="输出倍率" width="120">
|
|
||||||
<template #default="{ row }">
|
|
||||||
{{ row.output_price === 0 ? '免费' : calculateRate(row.output_price, row.currency) }}
|
|
||||||
</template>
|
|
||||||
</el-table-column>
|
|
||||||
<el-table-column width="80">
|
<el-table-column width="80">
|
||||||
<template #default="{ row }">
|
<template #default="{ row }">
|
||||||
<el-popover
|
<el-popover
|
||||||
@ -246,6 +246,24 @@ dall-e-3 按Token收费 OpenAI 美元 40.000000 40.000000"
|
|||||||
<el-input v-model="row.model" placeholder="请输入模型名称" />
|
<el-input v-model="row.model" placeholder="请输入模型名称" />
|
||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
|
<el-table-column label="模型类型" width="120">
|
||||||
|
<template #default="{ row }">
|
||||||
|
<el-select
|
||||||
|
v-model="row.model_type"
|
||||||
|
placeholder="请选择或输入"
|
||||||
|
allow-create
|
||||||
|
filterable
|
||||||
|
@create="handleModelTypeCreate"
|
||||||
|
>
|
||||||
|
<el-option
|
||||||
|
v-for="(label, value) in modelTypeMap"
|
||||||
|
:key="value"
|
||||||
|
:label="label"
|
||||||
|
:value="value"
|
||||||
|
/>
|
||||||
|
</el-select>
|
||||||
|
</template>
|
||||||
|
</el-table-column>
|
||||||
<el-table-column label="计费类型" width="120">
|
<el-table-column label="计费类型" width="120">
|
||||||
<template #default="{ row }">
|
<template #default="{ row }">
|
||||||
<el-select v-model="row.billing_type" placeholder="请选择">
|
<el-select v-model="row.billing_type" placeholder="请选择">
|
||||||
@ -319,6 +337,24 @@ dall-e-3 按Token收费 OpenAI 美元 40.000000 40.000000"
|
|||||||
<el-input v-model="form.model" />
|
<el-input v-model="form.model" />
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
</el-col>
|
</el-col>
|
||||||
|
<el-col :span="12">
|
||||||
|
<el-form-item label="模型类型">
|
||||||
|
<el-select
|
||||||
|
v-model="form.model_type"
|
||||||
|
placeholder="请选择或输入"
|
||||||
|
allow-create
|
||||||
|
filterable
|
||||||
|
@create="handleModelTypeCreate"
|
||||||
|
>
|
||||||
|
<el-option
|
||||||
|
v-for="(label, value) in modelTypeMap"
|
||||||
|
:key="value"
|
||||||
|
:label="label"
|
||||||
|
:value="value"
|
||||||
|
/>
|
||||||
|
</el-select>
|
||||||
|
</el-form-item>
|
||||||
|
</el-col>
|
||||||
<el-col :span="12">
|
<el-col :span="12">
|
||||||
<el-form-item label="计费类型">
|
<el-form-item label="计费类型">
|
||||||
<el-select v-model="form.billing_type" placeholder="请选择">
|
<el-select v-model="form.billing_type" placeholder="请选择">
|
||||||
@ -397,6 +433,7 @@ const prices = ref([])
|
|||||||
const dialogVisible = ref(false)
|
const dialogVisible = ref(false)
|
||||||
const form = ref({
|
const form = ref({
|
||||||
model: '',
|
model: '',
|
||||||
|
model_type: '',
|
||||||
billing_type: 'tokens',
|
billing_type: 'tokens',
|
||||||
channel_type: '',
|
channel_type: '',
|
||||||
currency: 'USD',
|
currency: 'USD',
|
||||||
@ -530,6 +567,7 @@ const handleAdd = () => {
|
|||||||
editingPrice.value = null
|
editingPrice.value = null
|
||||||
form.value = {
|
form.value = {
|
||||||
model: '',
|
model: '',
|
||||||
|
model_type: '',
|
||||||
billing_type: 'tokens',
|
billing_type: 'tokens',
|
||||||
channel_type: '',
|
channel_type: '',
|
||||||
currency: 'USD',
|
currency: 'USD',
|
||||||
@ -597,6 +635,7 @@ const handleSubmitResponse = async (response) => {
|
|||||||
editingPrice.value = null
|
editingPrice.value = null
|
||||||
form.value = {
|
form.value = {
|
||||||
model: '',
|
model: '',
|
||||||
|
model_type: '',
|
||||||
billing_type: 'tokens',
|
billing_type: 'tokens',
|
||||||
channel_type: '',
|
channel_type: '',
|
||||||
currency: 'USD',
|
currency: 'USD',
|
||||||
@ -633,9 +672,57 @@ const batchForms = ref([])
|
|||||||
const selectedRows = ref([])
|
const selectedRows = ref([])
|
||||||
const batchSubmitting = ref(false)
|
const batchSubmitting = ref(false)
|
||||||
|
|
||||||
|
// 添加模型类型映射
|
||||||
|
const modelTypeMap = ref({})
|
||||||
|
|
||||||
|
// 加载模型类型
|
||||||
|
const loadModelTypes = async () => {
|
||||||
|
try {
|
||||||
|
const response = await axios.get('/api/model-types')
|
||||||
|
const types = response.data
|
||||||
|
const map = {}
|
||||||
|
types.forEach(type => {
|
||||||
|
map[type.key] = type.label
|
||||||
|
})
|
||||||
|
modelTypeMap.value = map
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to load model types:', error)
|
||||||
|
ElMessage.error('加载模型类型失败')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理新增的模型类型
|
||||||
|
const handleModelTypeCreate = async (value) => {
|
||||||
|
// 如果输入的是中文描述,尝试查找对应的key
|
||||||
|
const existingKey = Object.entries(modelTypeMap.value).find(([_, label]) => label === value)?.[0]
|
||||||
|
if (existingKey) {
|
||||||
|
return existingKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果输入的是英文key,直接使用
|
||||||
|
let key = value
|
||||||
|
let label = value
|
||||||
|
if (!/^[a-zA-Z0-9_]+$/.test(value)) {
|
||||||
|
// 如果是中文描述,生成一个新的key
|
||||||
|
key = `type_${Date.now()}`
|
||||||
|
label = value
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await axios.post('/api/model-types', { key, label })
|
||||||
|
modelTypeMap.value[key] = label
|
||||||
|
return key
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to create model type:', error)
|
||||||
|
ElMessage.error('创建模型类型失败')
|
||||||
|
return 'other'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 创建新行的默认数据
|
// 创建新行的默认数据
|
||||||
const createNewRow = () => ({
|
const createNewRow = () => ({
|
||||||
model: '',
|
model: '',
|
||||||
|
model_type: '',
|
||||||
billing_type: 'tokens',
|
billing_type: 'tokens',
|
||||||
channel_type: '',
|
channel_type: '',
|
||||||
currency: 'USD',
|
currency: 'USD',
|
||||||
@ -839,6 +926,7 @@ watch(selectedProvider, () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
|
await loadModelTypes()
|
||||||
await loadPrices()
|
await loadPrices()
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1,5 +1,9 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 执行数据库迁移
|
||||||
|
echo "执行数据库迁移..."
|
||||||
|
./migrate
|
||||||
|
|
||||||
# 启动后端服务
|
# 启动后端服务
|
||||||
./main &
|
./main &
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user