数据库和核心初始化优化

- 数据库:
  - 从Database实例中移除AddKeyword和RemoveKeyword中的错误返回值。现在只在发生错误时返回错误。
  - 为RemoveKeyword添加受影响行数的返回值,以判断关键词是否被成功删除。
  - 优化AddPromptReply和DeletePromptReply,使用事务确保数据的一致性和完整性。
  - 调整GetAllPromptReplies以强制刷新缓存并更新缓存时间。

- 核心:
  - 重构init.go中的全局变量初始化,移除多余注释。
  - 在main.go中添加数据库关闭操作,确保资源在程序结束时被正确释放。

- 链接过滤器:
  - 重构LinkFilter服务,移除数据库实例字段。
  - 更新LinkFilter中的数据加载和关键词操作,使用core包中的数据库方法。
  - 添加LinkFilter的Close方法以关闭数据库连接。

- 消息处理器:
  - 移除message_handler.go中handleUpdate和handleAdminCommand中的数据库参数。
  - 更新RunMessageHandler以初始化数据库并确保在结束时关闭连接。
  - 调整handleListKeywords、handleAddKeyword、handleDeleteKeyword、handleDeleteContainingKeyword、handleListWhitelist、handleAddWhitelist和handleDeleteWhitelist,移除数据库参数。

- 提示回复:
  - 在prompt_reply.go中移除全局数据库变量。
  - 更新SetPromptReply和DeletePromptReply,使用core.DB代替db。
  - 调整GetPromptReply和ListPromptReplies使用更新后的数据库访问方法。

这些更改优化了代码结构,减少了全局状态,并提高了数据库操作的可靠性。
This commit is contained in:
wood chen 2024-09-19 23:07:11 +08:00
parent f27f87b708
commit b153581254
6 changed files with 127 additions and 112 deletions

View File

@ -91,13 +91,17 @@ func (d *Database) AddKeyword(keyword string) error {
return nil
}
func (d *Database) RemoveKeyword(keyword string) error {
_, err := d.db.Exec("DELETE FROM keywords WHERE keyword = ?", keyword)
func (d *Database) RemoveKeyword(keyword string) (bool, error) {
result, err := d.db.Exec("DELETE FROM keywords WHERE keyword = ?", keyword)
if err != nil {
return err
return false, err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return false, err
}
d.invalidateCache("keywords")
return nil
return rowsAffected > 0, nil
}
func (d *Database) GetAllKeywords() ([]string, error) {
@ -200,19 +204,41 @@ func (d *Database) WhitelistExists(domain string) (bool, error) {
}
func (d *Database) AddPromptReply(prompt, reply string) error {
_, err := d.db.Exec("INSERT OR REPLACE INTO prompt_replies (prompt, reply) VALUES (?, ?)", strings.ToLower(prompt), reply)
tx, err := d.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec("INSERT OR REPLACE INTO prompt_replies (prompt, reply) VALUES (?, ?)", strings.ToLower(prompt), reply)
if err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
d.invalidateCache("promptReplies")
return nil
}
func (d *Database) DeletePromptReply(prompt string) error {
_, err := d.db.Exec("DELETE FROM prompt_replies WHERE prompt = ?", strings.ToLower(prompt))
tx, err := d.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec("DELETE FROM prompt_replies WHERE prompt = ?", strings.ToLower(prompt))
if err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
d.invalidateCache("promptReplies")
return nil
}
@ -239,16 +265,15 @@ func (d *Database) GetAllPromptReplies() (map[string]string, error) {
d.mu.Lock()
defer d.mu.Unlock()
if d.promptRepliesCache == nil || time.Since(d.promptRepliesCacheTime) > 5*time.Minute {
promptReplies, err := d.fetchAllPromptReplies()
if err != nil {
return nil, err
}
d.promptRepliesCache = promptReplies
d.promptRepliesCacheTime = time.Now()
// 强制刷新缓存
promptReplies, err := d.fetchAllPromptReplies()
if err != nil {
return nil, err
}
d.promptRepliesCache = promptReplies
d.promptRepliesCacheTime = time.Now()
// 返回一个副本以防止外部修改缓存
// 返回一个副本
result := make(map[string]string, len(d.promptRepliesCache))
for k, v := range d.promptRepliesCache {
result[k] = v

View File

@ -22,6 +22,7 @@ var (
DB_FILE string
DEBUG_MODE bool
err error
DB *Database
)
func IsAdmin(userID int64) bool {
@ -52,20 +53,12 @@ func Init() error {
adminIDStr := os.Getenv("ADMIN_ID")
ADMIN_ID, err = mustParseInt64(adminIDStr)
if err != nil {
return fmt.Errorf("Invalid ADMIN_ID: %v", err)
return fmt.Errorf("invalid ADMIN_ID: %v", err)
}
// 初始化 Bot API
Bot, err = tgbotapi.NewBotAPI(BOT_TOKEN)
if err != nil {
return fmt.Errorf("创建 Bot API 失败: %v", err)
}
log.Printf("账户已授权 %s", Bot.Self.UserName)
// 初始化数据库
DB_FILE = filepath.Join("/app/data", "q58.db")
_, err = NewDatabase()
DB, err = NewDatabase()
if err != nil {
return fmt.Errorf("初始化数据库失败: %v", err)
}
@ -81,7 +74,7 @@ func Init() error {
chatIDStr := os.Getenv("CHAT_ID")
ChatID, err = mustParseInt64(chatIDStr)
if err != nil {
return fmt.Errorf("Invalid CHAT_ID: %v", err)
return fmt.Errorf("invalid CHAT_ID: %v", err)
}
// 初始化 Symbols

View File

@ -15,6 +15,7 @@ func main() {
if err != nil {
log.Fatalf("Failed to initialize service: %v", err)
}
defer core.DB.Close() // 确保在程序退出时关闭数据库连接
go binance.RunBinance()

View File

@ -14,7 +14,6 @@ import (
var logger = log.New(log.Writer(), "LinkFilter: ", log.Ldate|log.Ltime|log.Lshortfile)
type LinkFilter struct {
db *core.Database
keywords []string
whitelist []string
linkPattern *regexp.Regexp
@ -22,18 +21,11 @@ type LinkFilter struct {
}
func NewLinkFilter() (*LinkFilter, error) {
db, err := core.NewDatabase()
if err != nil {
return nil, err
}
lf := &LinkFilter{
db: db,
linkPattern: regexp.MustCompile(`(?i)\b(?:(?:https?://)?(?:(?:www\.)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:t\.me|telegram\.me))(?:/[^\s]*)?)`),
}
if err := lf.LoadDataFromFile(); err != nil {
db.Close() // Close the database if loading fails
return nil, err
}
@ -45,12 +37,12 @@ func (lf *LinkFilter) LoadDataFromFile() error {
defer lf.mu.Unlock()
var err error
lf.keywords, err = lf.db.GetAllKeywords()
lf.keywords, err = core.DB.GetAllKeywords()
if err != nil {
return err
}
lf.whitelist, err = lf.db.GetAllWhitelist()
lf.whitelist, err = core.DB.GetAllWhitelist()
if err != nil {
return err
}
@ -135,7 +127,7 @@ func (lf *LinkFilter) AddKeyword(keyword string) error {
return nil
}
}
err := lf.db.AddKeyword(keyword)
err := core.DB.AddKeyword(keyword)
if err != nil {
return err
}
@ -144,18 +136,19 @@ func (lf *LinkFilter) AddKeyword(keyword string) error {
}
func (lf *LinkFilter) RemoveKeyword(keyword string) bool {
for _, k := range lf.keywords {
if k == keyword {
lf.db.RemoveKeyword(keyword)
lf.LoadDataFromFile()
return true
}
removed, err := core.DB.RemoveKeyword(keyword)
if err != nil {
logger.Printf("Error removing keyword: %v", err)
return false
}
return false
if removed {
lf.LoadDataFromFile()
}
return removed
}
func (lf *LinkFilter) RemoveKeywordsContaining(substring string) ([]string, error) {
removed, err := lf.db.RemoveKeywordsContaining(substring)
removed, err := core.DB.RemoveKeywordsContaining(substring)
if err != nil {
return nil, err
}
@ -211,8 +204,3 @@ func (lf *LinkFilter) containsKeyword(link string) bool {
}
return false
}
// 新增 Close 方法
func (lf *LinkFilter) Close() error {
return lf.db.Close()
}

View File

@ -15,7 +15,7 @@ import (
)
// handleUpdate 处理所有传入的更新信息,包括消息和命令, 然后分开处理。
func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link_filter.LinkFilter, rateLimiter *core.RateLimiter, db *core.Database) {
func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link_filter.LinkFilter, rateLimiter *core.RateLimiter) {
// 检查更新是否包含消息,如果不包含则直接返回。
if update.Message == nil {
return
@ -23,7 +23,7 @@ func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link
// 如果消息来自私聊且发送者是预定义的管理员,调用处理管理员命令的函数。
if update.Message.Chat.Type == "private" && update.Message.From.ID == core.ADMIN_ID {
handleAdminCommand(bot, update.Message, db)
handleAdminCommand(bot, update.Message)
return
}
@ -34,15 +34,15 @@ func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link
}
// 处理管理员私聊消息
func handleAdminCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *core.Database) {
func handleAdminCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message) {
command := message.Command()
args := message.CommandArguments()
switch command {
case "add", "delete", "list", "deletecontaining":
HandleKeywordCommand(bot, message, command, args, db)
HandleKeywordCommand(bot, message, command, args)
case "addwhite", "delwhite", "listwhite":
HandleWhitelistCommand(bot, message, command, args, db)
HandleWhitelistCommand(bot, message, command, args)
case "prompt":
prompt_reply.HandlePromptCommand(bot, message)
default:
@ -118,11 +118,6 @@ func RunMessageHandler() error {
baseDelay := time.Second
maxDelay := 5 * time.Minute
delay := baseDelay
db, err := core.NewDatabase()
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
defer db.Close() // 确保在函数结束时关闭数据库连接
for {
err := func() error {
@ -155,7 +150,7 @@ func RunMessageHandler() error {
updates := bot.GetUpdatesChan(u)
for update := range updates {
go handleUpdate(bot, update, linkFilter, rateLimiter, db)
go handleUpdate(bot, update, linkFilter, rateLimiter)
}
return nil
@ -227,25 +222,25 @@ func sendErrorMessage(bot *tgbotapi.BotAPI, chatID int64, errMsg string) {
sendMessage(bot, chatID, errMsg)
}
func HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string, db *core.Database) {
func HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) {
args = strings.TrimSpace(args)
switch command {
case "list":
handleListKeywords(bot, message, db)
handleListKeywords(bot, message)
case "add":
handleAddKeyword(bot, message, args, db)
handleAddKeyword(bot, message, args)
case "delete":
handleDeleteKeyword(bot, message, args, db)
handleDeleteKeyword(bot, message, args)
case "deletecontaining":
handleDeleteContainingKeyword(bot, message, args, db)
handleDeleteContainingKeyword(bot, message, args)
default:
sendErrorMessage(bot, message.Chat.ID, "无效的命令或参数。")
}
}
func handleListKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *core.Database) {
keywords, err := db.GetAllKeywords()
func handleListKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message) {
keywords, err := core.DB.GetAllKeywords()
if err != nil {
sendErrorMessage(bot, message.Chat.ID, "获取关键词列表时发生错误。")
return
@ -257,19 +252,19 @@ func handleListKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *cor
}
}
func handleAddKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string, db *core.Database) {
func handleAddKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string) {
if keyword == "" {
sendErrorMessage(bot, message.Chat.ID, "请提供要添加的关键词。")
return
}
exists, err := db.KeywordExists(keyword)
exists, err := core.DB.KeywordExists(keyword)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, "检查关键词时发生错误。")
return
}
if !exists {
err = db.AddKeyword(keyword)
err = core.DB.AddKeyword(keyword)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, "添加关键词时发生错误。")
} else {
@ -280,33 +275,27 @@ func handleAddKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword s
}
}
func handleDeleteKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string, db *core.Database) {
func handleDeleteKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string) {
if keyword == "" {
sendErrorMessage(bot, message.Chat.ID, "请提供要删除的关键词。")
return
}
err := db.RemoveKeyword(keyword)
removed, err := core.DB.RemoveKeyword(keyword)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("删除关键词 '%s' 时发生错误: %v", keyword, err))
return
}
exists, err := db.KeywordExists(keyword)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("检查关键词 '%s' 是否存在时发生错误: %v", keyword, err))
return
}
if !exists {
if removed {
sendMessage(bot, message.Chat.ID, fmt.Sprintf("关键词 '%s' 已成功删除。", keyword))
} else {
handleSimilarKeywords(bot, message, keyword, db)
handleSimilarKeywords(bot, message, keyword)
}
}
func handleSimilarKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string, db *core.Database) {
similarKeywords, err := db.SearchKeywords(keyword)
func handleSimilarKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string) {
similarKeywords, err := core.DB.SearchKeywords(keyword)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, "搜索关键词时发生错误。")
return
@ -318,13 +307,13 @@ func handleSimilarKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyw
}
}
func handleDeleteContainingKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, substring string, db *core.Database) {
func handleDeleteContainingKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, substring string) {
if substring == "" {
sendErrorMessage(bot, message.Chat.ID, "请提供要删除的子字符串。")
return
}
removedKeywords, err := db.RemoveKeywordsContaining(substring)
removedKeywords, err := core.DB.RemoveKeywordsContaining(substring)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, "删除关键词时发生错误。")
return
@ -336,23 +325,23 @@ func handleDeleteContainingKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Messa
}
}
func HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string, db *core.Database) {
func HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) {
args = strings.TrimSpace(args)
switch command {
case "listwhite":
handleListWhitelist(bot, message, db)
handleListWhitelist(bot, message)
case "addwhite":
handleAddWhitelist(bot, message, args, db)
handleAddWhitelist(bot, message, args)
case "delwhite":
handleDeleteWhitelist(bot, message, args, db)
handleDeleteWhitelist(bot, message, args)
default:
sendErrorMessage(bot, message.Chat.ID, "无效的命令或参数。")
}
}
func handleListWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *core.Database) {
whitelist, err := db.GetAllWhitelist()
func handleListWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message) {
whitelist, err := core.DB.GetAllWhitelist()
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("获取白名单时发生错误: %v", err))
return
@ -364,14 +353,14 @@ func handleListWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *co
}
}
func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string, db *core.Database) {
func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string) {
if domain == "" {
sendErrorMessage(bot, message.Chat.ID, "请提供要添加的域名。")
return
}
domain = strings.ToLower(domain)
exists, err := db.WhitelistExists(domain)
exists, err := core.DB.WhitelistExists(domain)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("检查白名单时发生错误: %v", err))
return
@ -381,13 +370,13 @@ func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain
return
}
err = db.AddWhitelist(domain)
err = core.DB.AddWhitelist(domain)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("添加到白名单时发生错误: %v", err))
return
}
exists, err = db.WhitelistExists(domain)
exists, err = core.DB.WhitelistExists(domain)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("验证添加操作时发生错误: %v", err))
return
@ -399,14 +388,14 @@ func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain
}
}
func handleDeleteWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string, db *core.Database) {
func handleDeleteWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string) {
if domain == "" {
sendErrorMessage(bot, message.Chat.ID, "请提供要删除的域名。")
return
}
domain = strings.ToLower(domain)
exists, err := db.WhitelistExists(domain)
exists, err := core.DB.WhitelistExists(domain)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("检查白名单时发生错误: %v", err))
return
@ -416,13 +405,13 @@ func handleDeleteWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, doma
return
}
err = db.RemoveWhitelist(domain)
err = core.DB.RemoveWhitelist(domain)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("从白名单删除时发生错误: %v", err))
return
}
exists, err = db.WhitelistExists(domain)
exists, err = core.DB.WhitelistExists(domain)
if err != nil {
sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("验证删除操作时发生错误: %v", err))
return

View File

@ -4,32 +4,51 @@ import (
"fmt"
"log"
"strings"
"time"
"github.com/woodchen-ink/Q58Bot/core"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
)
var db *core.Database
func init() {
var err error
db, err = core.NewDatabase()
if err != nil {
log.Fatalf("Error initializing database: %v", err)
}
}
func SetPromptReply(prompt, reply string) error {
return db.AddPromptReply(prompt, reply)
err := core.DB.AddPromptReply(prompt, reply)
if err != nil {
log.Printf("提示回复: %s 设置提示回复失败: %v", time.Now().Format("2006/01/02 15:04:05"), err)
return err
}
// 获取当前所有的 prompt replies 来确认添加成功
promptReplies, err := core.DB.GetAllPromptReplies()
if err != nil {
log.Printf("提示回复: %s 添加后获取提示回复列表出错: %v", time.Now().Format("2006/01/02 15:04:05"), err)
} else {
log.Printf("提示回复: %s 设置提示回复成功。当前提示回复数量: %d", time.Now().Format("2006/01/02 15:04:05"), len(promptReplies))
}
return nil
}
func DeletePromptReply(prompt string) error {
return db.DeletePromptReply(prompt)
err := core.DB.DeletePromptReply(prompt)
if err != nil {
log.Printf("提示回复: %s 删除提示回复失败: %v", time.Now().Format("2006/01/02 15:04:05"), err)
return err
}
// 获取当前所有的 prompt replies 来确认删除成功
promptReplies, err := core.DB.GetAllPromptReplies()
if err != nil {
log.Printf("提示回复: %s 删除后获取提示回复列表出错: %v", time.Now().Format("2006/01/02 15:04:05"), err)
} else {
log.Printf("提示回复: %s 删除提示回复成功。当前提示回复数量: %d", time.Now().Format("2006/01/02 15:04:05"), len(promptReplies))
}
return nil
}
func GetPromptReply(message string) (string, bool) {
promptReplies, err := db.GetAllPromptReplies()
promptReplies, err := core.DB.GetAllPromptReplies()
if err != nil {
log.Printf("Error getting prompt replies: %v", err)
return "", false
@ -45,7 +64,7 @@ func GetPromptReply(message string) (string, bool) {
}
func ListPromptReplies() string {
replies, err := db.GetAllPromptReplies()
replies, err := core.DB.GetAllPromptReplies()
if err != nil {
log.Printf("Error getting prompt replies: %v", err)
return "Error retrieving prompt replies"