数据库优化与提示词服务增强

- 重构全局变量以支持新的数据库文件路径和调试模式设置。
- 直接使用常量代替硬编码路径,以增强代码的可维护性。
- 引入环境变量读取调试模式,提高系统的灵活性和可配置性。
- 数据库缓存机制调整,通过更细粒度的缓存管理提高性能。
- 提示词服务现在依赖于数据库存储,以保证数据的持久化。
- 更新提示词操作现在通过专门的数据库方法执行,使得代码更加模块化和清晰。
- 错误处理和日志记录在提示词操作中得到改进,提高系统的健壮性和可追踪性。
This commit is contained in:
wood chen 2024-09-18 15:40:25 +08:00
parent 6b4d776a89
commit b12915ef4f
5 changed files with 151 additions and 48 deletions

View File

@ -1,13 +1,26 @@
package core package core
import (
"os"
"path/filepath"
)
var ( var (
BOT_TOKEN string BOT_TOKEN string
ADMIN_ID int64 ADMIN_ID int64
DB_FILE string
DEBUG_MODE bool
) )
func InitGlobalVariables(botToken string, adminID int64) { func InitGlobalVariables(botToken string, adminID int64) {
BOT_TOKEN = botToken BOT_TOKEN = botToken
ADMIN_ID = adminID ADMIN_ID = adminID
// 设置数据库文件路径
DB_FILE = filepath.Join("/app/data", "q58.db")
// 从环境变量中读取调试模式设置
DEBUG_MODE = os.Getenv("DEBUG_MODE") == "true"
} }
func IsAdmin(userID int64) bool { func IsAdmin(userID int64) bool {

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
@ -11,24 +12,25 @@ import (
) )
type Database struct { type Database struct {
db *sql.DB db *sql.DB
dbFile string keywordsCache []string
keywordsCache []string whitelistCache []string
whitelistCache []string promptRepliesCache map[string]string
cacheTime time.Time keywordsCacheTime time.Time
mu sync.Mutex whitelistCacheTime time.Time
promptRepliesCacheTime time.Time
mu sync.Mutex
} }
func NewDatabase(dbFile string) (*Database, error) { func NewDatabase() (*Database, error) {
os.MkdirAll(filepath.Dir(dbFile), os.ModePerm) os.MkdirAll(filepath.Dir(DB_FILE), os.ModePerm)
db, err := sql.Open("sqlite", dbFile) db, err := sql.Open("sqlite", DB_FILE)
if err != nil { if err != nil {
return nil, err return nil, err
} }
database := &Database{ database := &Database{
db: db, db: db,
dbFile: dbFile,
} }
if err := database.createTables(); err != nil { if err := database.createTables(); err != nil {
@ -46,6 +48,8 @@ func (d *Database) createTables() error {
`CREATE TABLE IF NOT EXISTS whitelist `CREATE TABLE IF NOT EXISTS whitelist
(id INTEGER PRIMARY KEY, domain TEXT UNIQUE)`, (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)`,
`CREATE INDEX IF NOT EXISTS idx_domain ON whitelist(domain)`, `CREATE INDEX IF NOT EXISTS idx_domain ON whitelist(domain)`,
`CREATE TABLE IF NOT EXISTS prompt_replies
(prompt TEXT PRIMARY KEY, reply TEXT NOT NULL)`,
} }
for _, query := range queries { for _, query := range queries {
@ -82,7 +86,7 @@ func (d *Database) AddKeyword(keyword string) error {
if err != nil { if err != nil {
return err return err
} }
d.invalidateCache() d.invalidateCache("keywords")
return nil return nil
} }
@ -91,7 +95,7 @@ func (d *Database) RemoveKeyword(keyword string) error {
if err != nil { if err != nil {
return err return err
} }
d.invalidateCache() d.invalidateCache("keywords")
return nil return nil
} }
@ -99,13 +103,13 @@ func (d *Database) GetAllKeywords() ([]string, error) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if d.keywordsCache == nil || time.Since(d.cacheTime) > 5*time.Minute { if d.keywordsCache == nil || time.Since(d.keywordsCacheTime) > 5*time.Minute {
keywords, err := d.executeQuery("SELECT keyword FROM keywords") keywords, err := d.executeQuery("SELECT keyword FROM keywords")
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.keywordsCache = keywords d.keywordsCache = keywords
d.cacheTime = time.Now() d.keywordsCacheTime = time.Now()
} }
return d.keywordsCache, nil return d.keywordsCache, nil
@ -134,7 +138,7 @@ func (d *Database) RemoveKeywordsContaining(substring string) ([]string, error)
return nil, err return nil, err
} }
d.invalidateCache() d.invalidateCache("keywords")
return removedKeywords, nil return removedKeywords, nil
} }
@ -143,7 +147,7 @@ func (d *Database) AddWhitelist(domain string) error {
if err != nil { if err != nil {
return err return err
} }
d.invalidateCache() d.invalidateCache("whitelist")
return nil return nil
} }
@ -152,7 +156,7 @@ func (d *Database) RemoveWhitelist(domain string) error {
if err != nil { if err != nil {
return err return err
} }
d.invalidateCache() d.invalidateCache("whitelist")
return nil return nil
} }
@ -160,13 +164,13 @@ func (d *Database) GetAllWhitelist() ([]string, error) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
if d.whitelistCache == nil || time.Since(d.cacheTime) > 5*time.Minute { if d.whitelistCache == nil || time.Since(d.whitelistCacheTime) > 5*time.Minute {
whitelist, err := d.executeQuery("SELECT domain FROM whitelist") whitelist, err := d.executeQuery("SELECT domain FROM whitelist")
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.whitelistCache = whitelist d.whitelistCache = whitelist
d.cacheTime = time.Now() d.whitelistCacheTime = time.Now()
} }
return d.whitelistCache, nil return d.whitelistCache, nil
@ -194,12 +198,85 @@ func (d *Database) WhitelistExists(domain string) (bool, error) {
return count > 0, nil return count > 0, nil
} }
func (d *Database) invalidateCache() { func (d *Database) AddPromptReply(prompt, reply string) error {
_, err := d.db.Exec("INSERT OR REPLACE INTO prompt_replies (prompt, reply) VALUES (?, ?)", strings.ToLower(prompt), reply)
if err != nil {
return err
}
d.invalidateCache("promptReplies")
return nil
}
func (d *Database) DeletePromptReply(prompt string) error {
_, err := d.db.Exec("DELETE FROM prompt_replies WHERE prompt = ?", strings.ToLower(prompt))
if err != nil {
return err
}
d.invalidateCache("promptReplies")
return nil
}
func (d *Database) fetchAllPromptReplies() (map[string]string, error) {
rows, err := d.db.Query("SELECT prompt, reply FROM prompt_replies")
if err != nil {
return nil, err
}
defer rows.Close()
promptReplies := make(map[string]string)
for rows.Next() {
var prompt, reply string
if err := rows.Scan(&prompt, &reply); err != nil {
return nil, err
}
promptReplies[prompt] = reply
}
return promptReplies, nil
}
func (d *Database) GetAllPromptReplies() (map[string]string, error) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
d.keywordsCache = nil
d.whitelistCache = nil if d.promptRepliesCache == nil || time.Since(d.promptRepliesCacheTime) > 5*time.Minute {
d.cacheTime = time.Time{} promptReplies, err := d.fetchAllPromptReplies()
if err != nil {
return nil, err
}
d.promptRepliesCache = promptReplies
d.promptRepliesCacheTime = time.Now()
}
// 返回一个副本以防止外部修改缓存
result := make(map[string]string, len(d.promptRepliesCache))
for k, v := range d.promptRepliesCache {
result[k] = v
}
return result, nil
}
func (d *Database) invalidateCache(cacheType string) {
d.mu.Lock()
defer d.mu.Unlock()
switch cacheType {
case "keywords":
d.keywordsCache = nil
d.keywordsCacheTime = time.Time{}
case "whitelist":
d.whitelistCache = nil
d.whitelistCacheTime = time.Time{}
case "promptReplies":
d.promptRepliesCache = nil
d.promptRepliesCacheTime = time.Time{}
default:
// 清除所有缓存
d.keywordsCache = nil
d.whitelistCache = nil
d.promptRepliesCache = nil
d.keywordsCacheTime = time.Time{}
d.whitelistCacheTime = time.Time{}
d.promptRepliesCacheTime = time.Time{}
}
} }
func (d *Database) Close() error { func (d *Database) Close() error {

View File

@ -82,7 +82,7 @@ func startBot() error {
return fmt.Errorf("failed to create bot: %w", err) return fmt.Errorf("failed to create bot: %w", err)
} }
bot.Debug = debugMode bot.Debug = core.DEBUG_MODE
log.Printf("Authorized on account %s", bot.Self.UserName) log.Printf("Authorized on account %s", bot.Self.UserName)
@ -91,7 +91,7 @@ func startBot() error {
return fmt.Errorf("error registering commands: %w", err) return fmt.Errorf("error registering commands: %w", err)
} }
linkFilter, err := NewLinkFilter(dbFile) linkFilter, err := NewLinkFilter(core.DB_FILE)
if err != nil { if err != nil {
return fmt.Errorf("failed to create LinkFilter: %v", err) return fmt.Errorf("failed to create LinkFilter: %v", err)
} }

View File

@ -1,21 +1,13 @@
package service package service
import ( import (
"os"
"time" "time"
"github.com/woodchen-ink/Q58Bot/core" "github.com/woodchen-ink/Q58Bot/core"
) )
var (
dbFile string
debugMode bool
)
func Init(botToken string, adminID int64) { func Init(botToken string, adminID int64) {
core.InitGlobalVariables(botToken, adminID) core.InitGlobalVariables(botToken, adminID)
dbFile = "/app/data/q58.db"
debugMode = os.Getenv("DEBUG_MODE") == "true"
// 设置时区 // 设置时区
loc := time.FixedZone("Asia/Singapore", 8*60*60) loc := time.FixedZone("Asia/Singapore", 8*60*60)

View File

@ -2,6 +2,7 @@ package service
import ( import (
"fmt" "fmt"
"log"
"strings" "strings"
"sync" "sync"
@ -13,23 +14,35 @@ import (
var ( var (
promptReplies = make(map[string]string) promptReplies = make(map[string]string)
promptMutex sync.RWMutex promptMutex sync.RWMutex
db *core.Database
) )
func SetPromptReply(prompt, reply string) { func InitPromptService(database *core.Database) error {
promptMutex.Lock() db = database
defer promptMutex.Unlock() return loadPromptRepliesFromDB()
promptReplies[strings.ToLower(prompt)] = reply
} }
func DeletePromptReply(prompt string) { func loadPromptRepliesFromDB() error {
promptMutex.Lock() var err error
defer promptMutex.Unlock() promptReplies, err = db.GetAllPromptReplies()
delete(promptReplies, strings.ToLower(prompt)) return err
}
func SetPromptReply(prompt, reply string) error {
return db.AddPromptReply(prompt, reply)
}
func DeletePromptReply(prompt string) error {
return db.DeletePromptReply(prompt)
} }
func GetPromptReply(message string) (string, bool) { func GetPromptReply(message string) (string, bool) {
promptMutex.RLock() promptReplies, err := db.GetAllPromptReplies()
defer promptMutex.RUnlock() if err != nil {
log.Printf("Error getting prompt replies: %v", err)
return "", false
}
for prompt, reply := range promptReplies { for prompt, reply := range promptReplies {
if strings.Contains(strings.ToLower(message), prompt) { if strings.Contains(strings.ToLower(message), prompt) {
return reply, true return reply, true
@ -77,14 +90,22 @@ func HandlePromptCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message) {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "请同时提供提示词和回复。")) bot.Send(tgbotapi.NewMessage(message.Chat.ID, "请同时提供提示词和回复。"))
return return
} }
SetPromptReply(promptAndReply[0], promptAndReply[1]) err := SetPromptReply(promptAndReply[0], promptAndReply[1])
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("设置提示词失败:%v", err)))
return
}
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("已设置提示词 '%s' 的回复。", promptAndReply[0]))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("已设置提示词 '%s' 的回复。", promptAndReply[0])))
case "delete": case "delete":
if len(args) < 3 { if len(args) < 3 {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "使用方法: /prompt delete <提示词>")) bot.Send(tgbotapi.NewMessage(message.Chat.ID, "使用方法: /prompt delete <提示词>"))
return return
} }
DeletePromptReply(args[2]) err := DeletePromptReply(args[2])
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("删除提示词失败:%v", err)))
return
}
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("已删除提示词 '%s' 的回复。", args[2]))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("已删除提示词 '%s' 的回复。", args[2])))
case "list": case "list":
bot.Send(tgbotapi.NewMessage(message.Chat.ID, ListPromptReplies())) bot.Send(tgbotapi.NewMessage(message.Chat.ID, ListPromptReplies()))