diff --git a/core/database.go b/core/database.go index 37caf6f..f9134cf 100644 --- a/core/database.go +++ b/core/database.go @@ -91,13 +91,17 @@ func (d *Database) AddKeyword(keyword string) error { return nil } -func (d *Database) RemoveKeyword(keyword string) error { - _, err := d.db.Exec("DELETE FROM keywords WHERE keyword = ?", keyword) +func (d *Database) RemoveKeyword(keyword string) (bool, error) { + result, err := d.db.Exec("DELETE FROM keywords WHERE keyword = ?", keyword) if err != nil { - return err + return false, err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return false, err } d.invalidateCache("keywords") - return nil + return rowsAffected > 0, nil } func (d *Database) GetAllKeywords() ([]string, error) { @@ -200,19 +204,41 @@ func (d *Database) WhitelistExists(domain string) (bool, error) { } func (d *Database) AddPromptReply(prompt, reply string) error { - _, err := d.db.Exec("INSERT OR REPLACE INTO prompt_replies (prompt, reply) VALUES (?, ?)", strings.ToLower(prompt), reply) + tx, err := d.db.Begin() if err != nil { return err } + defer tx.Rollback() + + _, err = tx.Exec("INSERT OR REPLACE INTO prompt_replies (prompt, reply) VALUES (?, ?)", strings.ToLower(prompt), reply) + if err != nil { + return err + } + + if err = tx.Commit(); err != nil { + return err + } + d.invalidateCache("promptReplies") return nil } func (d *Database) DeletePromptReply(prompt string) error { - _, err := d.db.Exec("DELETE FROM prompt_replies WHERE prompt = ?", strings.ToLower(prompt)) + tx, err := d.db.Begin() if err != nil { return err } + defer tx.Rollback() + + _, err = tx.Exec("DELETE FROM prompt_replies WHERE prompt = ?", strings.ToLower(prompt)) + if err != nil { + return err + } + + if err = tx.Commit(); err != nil { + return err + } + d.invalidateCache("promptReplies") return nil } @@ -239,16 +265,15 @@ func (d *Database) GetAllPromptReplies() (map[string]string, error) { d.mu.Lock() defer d.mu.Unlock() - if d.promptRepliesCache == nil || time.Since(d.promptRepliesCacheTime) > 5*time.Minute { - promptReplies, err := d.fetchAllPromptReplies() - if err != nil { - return nil, err - } - d.promptRepliesCache = promptReplies - d.promptRepliesCacheTime = time.Now() + // 强制刷新缓存 + promptReplies, err := d.fetchAllPromptReplies() + if err != nil { + return nil, err } + d.promptRepliesCache = promptReplies + d.promptRepliesCacheTime = time.Now() - // 返回一个副本以防止外部修改缓存 + // 返回一个副本 result := make(map[string]string, len(d.promptRepliesCache)) for k, v := range d.promptRepliesCache { result[k] = v diff --git a/core/init.go b/core/init.go index 5b70812..6c59e7a 100644 --- a/core/init.go +++ b/core/init.go @@ -22,6 +22,7 @@ var ( DB_FILE string DEBUG_MODE bool err error + DB *Database ) func IsAdmin(userID int64) bool { @@ -52,20 +53,12 @@ func Init() error { adminIDStr := os.Getenv("ADMIN_ID") ADMIN_ID, err = mustParseInt64(adminIDStr) if err != nil { - return fmt.Errorf("Invalid ADMIN_ID: %v", err) + return fmt.Errorf("invalid ADMIN_ID: %v", err) } - // 初始化 Bot API - Bot, err = tgbotapi.NewBotAPI(BOT_TOKEN) - if err != nil { - return fmt.Errorf("创建 Bot API 失败: %v", err) - } - - log.Printf("账户已授权 %s", Bot.Self.UserName) - // 初始化数据库 DB_FILE = filepath.Join("/app/data", "q58.db") - _, err = NewDatabase() + DB, err = NewDatabase() if err != nil { return fmt.Errorf("初始化数据库失败: %v", err) } @@ -81,7 +74,7 @@ func Init() error { chatIDStr := os.Getenv("CHAT_ID") ChatID, err = mustParseInt64(chatIDStr) if err != nil { - return fmt.Errorf("Invalid CHAT_ID: %v", err) + return fmt.Errorf("invalid CHAT_ID: %v", err) } // 初始化 Symbols diff --git a/main.go b/main.go index e70a9cd..ca44895 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ func main() { if err != nil { log.Fatalf("Failed to initialize service: %v", err) } + defer core.DB.Close() // 确保在程序退出时关闭数据库连接 go binance.RunBinance() diff --git a/service/link_filter/link_filter.go b/service/link_filter/link_filter.go index 348a65f..54bd78d 100644 --- a/service/link_filter/link_filter.go +++ b/service/link_filter/link_filter.go @@ -14,7 +14,6 @@ import ( var logger = log.New(log.Writer(), "LinkFilter: ", log.Ldate|log.Ltime|log.Lshortfile) type LinkFilter struct { - db *core.Database keywords []string whitelist []string linkPattern *regexp.Regexp @@ -22,18 +21,11 @@ type LinkFilter struct { } func NewLinkFilter() (*LinkFilter, error) { - db, err := core.NewDatabase() - if err != nil { - return nil, err - } - lf := &LinkFilter{ - db: db, linkPattern: regexp.MustCompile(`(?i)\b(?:(?:https?://)?(?:(?:www\.)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:t\.me|telegram\.me))(?:/[^\s]*)?)`), } if err := lf.LoadDataFromFile(); err != nil { - db.Close() // Close the database if loading fails return nil, err } @@ -45,12 +37,12 @@ func (lf *LinkFilter) LoadDataFromFile() error { defer lf.mu.Unlock() var err error - lf.keywords, err = lf.db.GetAllKeywords() + lf.keywords, err = core.DB.GetAllKeywords() if err != nil { return err } - lf.whitelist, err = lf.db.GetAllWhitelist() + lf.whitelist, err = core.DB.GetAllWhitelist() if err != nil { return err } @@ -135,7 +127,7 @@ func (lf *LinkFilter) AddKeyword(keyword string) error { return nil } } - err := lf.db.AddKeyword(keyword) + err := core.DB.AddKeyword(keyword) if err != nil { return err } @@ -144,18 +136,19 @@ func (lf *LinkFilter) AddKeyword(keyword string) error { } func (lf *LinkFilter) RemoveKeyword(keyword string) bool { - for _, k := range lf.keywords { - if k == keyword { - lf.db.RemoveKeyword(keyword) - lf.LoadDataFromFile() - return true - } + removed, err := core.DB.RemoveKeyword(keyword) + if err != nil { + logger.Printf("Error removing keyword: %v", err) + return false } - return false + if removed { + lf.LoadDataFromFile() + } + return removed } func (lf *LinkFilter) RemoveKeywordsContaining(substring string) ([]string, error) { - removed, err := lf.db.RemoveKeywordsContaining(substring) + removed, err := core.DB.RemoveKeywordsContaining(substring) if err != nil { return nil, err } @@ -211,8 +204,3 @@ func (lf *LinkFilter) containsKeyword(link string) bool { } return false } - -// 新增 Close 方法 -func (lf *LinkFilter) Close() error { - return lf.db.Close() -} diff --git a/service/message_handler.go b/service/message_handler.go index 2961c29..99dc8ac 100644 --- a/service/message_handler.go +++ b/service/message_handler.go @@ -15,7 +15,7 @@ import ( ) // handleUpdate 处理所有传入的更新信息,包括消息和命令, 然后分开处理。 -func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link_filter.LinkFilter, rateLimiter *core.RateLimiter, db *core.Database) { +func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link_filter.LinkFilter, rateLimiter *core.RateLimiter) { // 检查更新是否包含消息,如果不包含则直接返回。 if update.Message == nil { return @@ -23,7 +23,7 @@ func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link // 如果消息来自私聊且发送者是预定义的管理员,调用处理管理员命令的函数。 if update.Message.Chat.Type == "private" && update.Message.From.ID == core.ADMIN_ID { - handleAdminCommand(bot, update.Message, db) + handleAdminCommand(bot, update.Message) return } @@ -34,15 +34,15 @@ func handleUpdate(bot *tgbotapi.BotAPI, update tgbotapi.Update, linkFilter *link } // 处理管理员私聊消息 -func handleAdminCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *core.Database) { +func handleAdminCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message) { command := message.Command() args := message.CommandArguments() switch command { case "add", "delete", "list", "deletecontaining": - HandleKeywordCommand(bot, message, command, args, db) + HandleKeywordCommand(bot, message, command, args) case "addwhite", "delwhite", "listwhite": - HandleWhitelistCommand(bot, message, command, args, db) + HandleWhitelistCommand(bot, message, command, args) case "prompt": prompt_reply.HandlePromptCommand(bot, message) default: @@ -118,11 +118,6 @@ func RunMessageHandler() error { baseDelay := time.Second maxDelay := 5 * time.Minute delay := baseDelay - db, err := core.NewDatabase() - if err != nil { - return fmt.Errorf("failed to initialize database: %w", err) - } - defer db.Close() // 确保在函数结束时关闭数据库连接 for { err := func() error { @@ -155,7 +150,7 @@ func RunMessageHandler() error { updates := bot.GetUpdatesChan(u) for update := range updates { - go handleUpdate(bot, update, linkFilter, rateLimiter, db) + go handleUpdate(bot, update, linkFilter, rateLimiter) } return nil @@ -227,25 +222,25 @@ func sendErrorMessage(bot *tgbotapi.BotAPI, chatID int64, errMsg string) { sendMessage(bot, chatID, errMsg) } -func HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string, db *core.Database) { +func HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) { args = strings.TrimSpace(args) switch command { case "list": - handleListKeywords(bot, message, db) + handleListKeywords(bot, message) case "add": - handleAddKeyword(bot, message, args, db) + handleAddKeyword(bot, message, args) case "delete": - handleDeleteKeyword(bot, message, args, db) + handleDeleteKeyword(bot, message, args) case "deletecontaining": - handleDeleteContainingKeyword(bot, message, args, db) + handleDeleteContainingKeyword(bot, message, args) default: sendErrorMessage(bot, message.Chat.ID, "无效的命令或参数。") } } -func handleListKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *core.Database) { - keywords, err := db.GetAllKeywords() +func handleListKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message) { + keywords, err := core.DB.GetAllKeywords() if err != nil { sendErrorMessage(bot, message.Chat.ID, "获取关键词列表时发生错误。") return @@ -257,19 +252,19 @@ func handleListKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *cor } } -func handleAddKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string, db *core.Database) { +func handleAddKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string) { if keyword == "" { sendErrorMessage(bot, message.Chat.ID, "请提供要添加的关键词。") return } - exists, err := db.KeywordExists(keyword) + exists, err := core.DB.KeywordExists(keyword) if err != nil { sendErrorMessage(bot, message.Chat.ID, "检查关键词时发生错误。") return } if !exists { - err = db.AddKeyword(keyword) + err = core.DB.AddKeyword(keyword) if err != nil { sendErrorMessage(bot, message.Chat.ID, "添加关键词时发生错误。") } else { @@ -280,33 +275,27 @@ func handleAddKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword s } } -func handleDeleteKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string, db *core.Database) { +func handleDeleteKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string) { if keyword == "" { sendErrorMessage(bot, message.Chat.ID, "请提供要删除的关键词。") return } - err := db.RemoveKeyword(keyword) + removed, err := core.DB.RemoveKeyword(keyword) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("删除关键词 '%s' 时发生错误: %v", keyword, err)) return } - exists, err := db.KeywordExists(keyword) - if err != nil { - sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("检查关键词 '%s' 是否存在时发生错误: %v", keyword, err)) - return - } - - if !exists { + if removed { sendMessage(bot, message.Chat.ID, fmt.Sprintf("关键词 '%s' 已成功删除。", keyword)) } else { - handleSimilarKeywords(bot, message, keyword, db) + handleSimilarKeywords(bot, message, keyword) } } -func handleSimilarKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string, db *core.Database) { - similarKeywords, err := db.SearchKeywords(keyword) +func handleSimilarKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyword string) { + similarKeywords, err := core.DB.SearchKeywords(keyword) if err != nil { sendErrorMessage(bot, message.Chat.ID, "搜索关键词时发生错误。") return @@ -318,13 +307,13 @@ func handleSimilarKeywords(bot *tgbotapi.BotAPI, message *tgbotapi.Message, keyw } } -func handleDeleteContainingKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, substring string, db *core.Database) { +func handleDeleteContainingKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Message, substring string) { if substring == "" { sendErrorMessage(bot, message.Chat.ID, "请提供要删除的子字符串。") return } - removedKeywords, err := db.RemoveKeywordsContaining(substring) + removedKeywords, err := core.DB.RemoveKeywordsContaining(substring) if err != nil { sendErrorMessage(bot, message.Chat.ID, "删除关键词时发生错误。") return @@ -336,23 +325,23 @@ func handleDeleteContainingKeyword(bot *tgbotapi.BotAPI, message *tgbotapi.Messa } } -func HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string, db *core.Database) { +func HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) { args = strings.TrimSpace(args) switch command { case "listwhite": - handleListWhitelist(bot, message, db) + handleListWhitelist(bot, message) case "addwhite": - handleAddWhitelist(bot, message, args, db) + handleAddWhitelist(bot, message, args) case "delwhite": - handleDeleteWhitelist(bot, message, args, db) + handleDeleteWhitelist(bot, message, args) default: sendErrorMessage(bot, message.Chat.ID, "无效的命令或参数。") } } -func handleListWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *core.Database) { - whitelist, err := db.GetAllWhitelist() +func handleListWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message) { + whitelist, err := core.DB.GetAllWhitelist() if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("获取白名单时发生错误: %v", err)) return @@ -364,14 +353,14 @@ func handleListWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, db *co } } -func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string, db *core.Database) { +func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string) { if domain == "" { sendErrorMessage(bot, message.Chat.ID, "请提供要添加的域名。") return } domain = strings.ToLower(domain) - exists, err := db.WhitelistExists(domain) + exists, err := core.DB.WhitelistExists(domain) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("检查白名单时发生错误: %v", err)) return @@ -381,13 +370,13 @@ func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain return } - err = db.AddWhitelist(domain) + err = core.DB.AddWhitelist(domain) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("添加到白名单时发生错误: %v", err)) return } - exists, err = db.WhitelistExists(domain) + exists, err = core.DB.WhitelistExists(domain) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("验证添加操作时发生错误: %v", err)) return @@ -399,14 +388,14 @@ func handleAddWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain } } -func handleDeleteWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string, db *core.Database) { +func handleDeleteWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, domain string) { if domain == "" { sendErrorMessage(bot, message.Chat.ID, "请提供要删除的域名。") return } domain = strings.ToLower(domain) - exists, err := db.WhitelistExists(domain) + exists, err := core.DB.WhitelistExists(domain) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("检查白名单时发生错误: %v", err)) return @@ -416,13 +405,13 @@ func handleDeleteWhitelist(bot *tgbotapi.BotAPI, message *tgbotapi.Message, doma return } - err = db.RemoveWhitelist(domain) + err = core.DB.RemoveWhitelist(domain) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("从白名单删除时发生错误: %v", err)) return } - exists, err = db.WhitelistExists(domain) + exists, err = core.DB.WhitelistExists(domain) if err != nil { sendErrorMessage(bot, message.Chat.ID, fmt.Sprintf("验证删除操作时发生错误: %v", err)) return diff --git a/service/prompt_reply/prompt_reply.go b/service/prompt_reply/prompt_reply.go index 4935bc7..3570609 100644 --- a/service/prompt_reply/prompt_reply.go +++ b/service/prompt_reply/prompt_reply.go @@ -4,32 +4,51 @@ import ( "fmt" "log" "strings" + "time" "github.com/woodchen-ink/Q58Bot/core" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" ) -var db *core.Database - -func init() { - var err error - db, err = core.NewDatabase() - if err != nil { - log.Fatalf("Error initializing database: %v", err) - } -} - func SetPromptReply(prompt, reply string) error { - return db.AddPromptReply(prompt, reply) + err := core.DB.AddPromptReply(prompt, reply) + if err != nil { + log.Printf("提示回复: %s 设置提示回复失败: %v", time.Now().Format("2006/01/02 15:04:05"), err) + return err + } + + // 获取当前所有的 prompt replies 来确认添加成功 + promptReplies, err := core.DB.GetAllPromptReplies() + if err != nil { + log.Printf("提示回复: %s 添加后获取提示回复列表出错: %v", time.Now().Format("2006/01/02 15:04:05"), err) + } else { + log.Printf("提示回复: %s 设置提示回复成功。当前提示回复数量: %d", time.Now().Format("2006/01/02 15:04:05"), len(promptReplies)) + } + + return nil } func DeletePromptReply(prompt string) error { - return db.DeletePromptReply(prompt) + err := core.DB.DeletePromptReply(prompt) + if err != nil { + log.Printf("提示回复: %s 删除提示回复失败: %v", time.Now().Format("2006/01/02 15:04:05"), err) + return err + } + + // 获取当前所有的 prompt replies 来确认删除成功 + promptReplies, err := core.DB.GetAllPromptReplies() + if err != nil { + log.Printf("提示回复: %s 删除后获取提示回复列表出错: %v", time.Now().Format("2006/01/02 15:04:05"), err) + } else { + log.Printf("提示回复: %s 删除提示回复成功。当前提示回复数量: %d", time.Now().Format("2006/01/02 15:04:05"), len(promptReplies)) + } + + return nil } func GetPromptReply(message string) (string, bool) { - promptReplies, err := db.GetAllPromptReplies() + promptReplies, err := core.DB.GetAllPromptReplies() if err != nil { log.Printf("Error getting prompt replies: %v", err) return "", false @@ -45,7 +64,7 @@ func GetPromptReply(message string) (string, bool) { } func ListPromptReplies() string { - replies, err := db.GetAllPromptReplies() + replies, err := core.DB.GetAllPromptReplies() if err != nil { log.Printf("Error getting prompt replies: %v", err) return "Error retrieving prompt replies"