diff --git a/core/database.go b/core/database.go index 48482d1..55adee2 100644 --- a/core/database.go +++ b/core/database.go @@ -120,17 +120,37 @@ func (d *Database) GetAllKeywords() ([]string, error) { return d.keywordsCache, nil } -func (d *Database) RemoveKeywordsContaining(substring string) error { - _, err := d.db.Exec("DELETE FROM keywords WHERE keyword LIKE ?", "%"+substring+"%") +func (d *Database) RemoveKeywordsContaining(substring string) ([]string, error) { + // 首先获取要删除的关键词列表 + rows, err := d.db.Query("SELECT keyword FROM keywords WHERE keyword LIKE ?", "%"+substring+"%") if err != nil { - return err + return nil, err } + defer rows.Close() + + var removedKeywords []string + for rows.Next() { + var keyword string + if err := rows.Scan(&keyword); err != nil { + return nil, err + } + removedKeywords = append(removedKeywords, keyword) + } + + // 执行删除操作 + _, err = d.db.Exec("DELETE FROM keywords WHERE keyword LIKE ?", "%"+substring+"%") + if err != nil { + return nil, err + } + + // 从 FTS 表中也删除这些关键词 _, err = d.db.Exec("DELETE FROM keywords_fts WHERE keyword LIKE ?", "%"+substring+"%") if err != nil { - return err + return nil, err } + d.invalidateCache() - return nil + return removedKeywords, nil } func (d *Database) AddWhitelist(domain string) error { @@ -171,6 +191,24 @@ func (d *Database) SearchKeywords(pattern string) ([]string, error) { return d.executeQuery("SELECT keyword FROM keywords_fts WHERE keyword MATCH ?", pattern) } +func (d *Database) KeywordExists(keyword string) (bool, error) { + var count int + err := d.db.QueryRow("SELECT COUNT(*) FROM keywords WHERE keyword = ?", keyword).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +func (d *Database) WhitelistExists(domain string) (bool, error) { + var count int + err := d.db.QueryRow("SELECT COUNT(*) FROM whitelist WHERE domain = ?", domain).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + func (d *Database) invalidateCache() { d.mu.Lock() defer d.mu.Unlock() diff --git a/core/link_filter.go b/core/link_filter.go index e1f8767..f041b29 100644 --- a/core/link_filter.go +++ b/core/link_filter.go @@ -19,19 +19,34 @@ type LinkFilter struct { linkPattern *regexp.Regexp } -func NewLinkFilter(dbFile string) *LinkFilter { +func NewLinkFilter(dbFile string) (*LinkFilter, error) { + db, err := NewDatabase(dbFile) + if err != nil { + return nil, err + } lf := &LinkFilter{ - db: NewDatabase(dbFile), + db: db, } lf.linkPattern = regexp.MustCompile(`(?i)\b(?:(?:https?://)?(?:(?:www\.)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:t\.me|telegram\.me))(?:/[^\s]*)?)`) - lf.LoadDataFromFile() - return lf + err = lf.LoadDataFromFile() + if err != nil { + return nil, err + } + return lf, nil } -func (lf *LinkFilter) LoadDataFromFile() { - lf.keywords = lf.db.GetAllKeywords() - lf.whitelist = lf.db.GetAllWhitelist() +func (lf *LinkFilter) LoadDataFromFile() error { + var err error + lf.keywords, err = lf.db.GetAllKeywords() + if err != nil { + return err + } + lf.whitelist, err = lf.db.GetAllWhitelist() + if err != nil { + return err + } logger.Printf("Loaded %d keywords and %d whitelist entries from database", len(lf.keywords), len(lf.whitelist)) + return nil } func (lf *LinkFilter) NormalizeLink(link string) string { @@ -74,7 +89,7 @@ func (lf *LinkFilter) IsWhitelisted(link string) bool { return false } -func (lf *LinkFilter) AddKeyword(keyword string) { +func (lf *LinkFilter) AddKeyword(keyword string) error { if lf.linkPattern.MatchString(keyword) { keyword = lf.NormalizeLink(keyword) } @@ -82,12 +97,15 @@ func (lf *LinkFilter) AddKeyword(keyword string) { for _, k := range lf.keywords { if k == keyword { logger.Printf("Keyword already exists: %s", keyword) - return + return nil } } - lf.db.AddKeyword(keyword) + err := lf.db.AddKeyword(keyword) + if err != nil { + return err + } logger.Printf("New keyword added: %s", keyword) - lf.LoadDataFromFile() + return lf.LoadDataFromFile() } func (lf *LinkFilter) RemoveKeyword(keyword string) bool { @@ -101,10 +119,16 @@ func (lf *LinkFilter) RemoveKeyword(keyword string) bool { return false } -func (lf *LinkFilter) RemoveKeywordsContaining(substring string) []string { - removed := lf.db.RemoveKeywordsContaining(substring) - lf.LoadDataFromFile() - return removed +func (lf *LinkFilter) RemoveKeywordsContaining(substring string) ([]string, error) { + removed, err := lf.db.RemoveKeywordsContaining(substring) + if err != nil { + return nil, err + } + err = lf.LoadDataFromFile() + if err != nil { + return nil, err + } + return removed, nil } func (lf *LinkFilter) ShouldFilter(text string) (bool, []string) { @@ -147,7 +171,11 @@ func (lf *LinkFilter) ShouldFilter(text string) (bool, []string) { func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) { switch command { case "list": - keywords := lf.db.GetAllKeywords() + keywords, err := lf.db.GetAllKeywords() + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "获取关键词列表时发生错误。")) + return + } if len(keywords) == 0 { bot.Send(tgbotapi.NewMessage(message.Chat.ID, "关键词列表为空。")) } else { @@ -156,20 +184,34 @@ func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbota case "add": if args != "" { keyword := args - if !lf.db.KeywordExists(keyword) { - lf.AddKeyword(keyword) + exists, err := lf.db.KeywordExists(keyword) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "检查关键词时发生错误。")) + return + } + if !exists { + err = lf.AddKeyword(keyword) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "添加关键词时发生错误。")) + return + } bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已添加。", keyword))) } else { bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已存在。", keyword))) } } + case "delete": if args != "" { keyword := args if lf.RemoveKeyword(keyword) { bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已删除。", keyword))) } else { - similarKeywords := lf.db.SearchKeywords(keyword) + similarKeywords, err := lf.db.SearchKeywords(keyword) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "搜索关键词时发生错误。")) + return + } if len(similarKeywords) > 0 { SendLongMessage(bot, message.Chat.ID, fmt.Sprintf("未找到精确匹配的关键词 '%s'。\n\n以下是相似的关键词:", keyword), similarKeywords) } else { @@ -180,7 +222,11 @@ func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbota case "deletecontaining": if args != "" { substring := args - removedKeywords := lf.RemoveKeywordsContaining(substring) + removedKeywords, err := lf.RemoveKeywordsContaining(substring) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "删除关键词时发生错误。")) + return + } if len(removedKeywords) > 0 { SendLongMessage(bot, message.Chat.ID, fmt.Sprintf("已删除包含 '%s' 的以下关键词:", substring), removedKeywords) } else { @@ -195,17 +241,30 @@ func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbota func (lf *LinkFilter) HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) { switch command { case "listwhite": - whitelist := lf.db.GetAllWhitelist() + whitelist, err := lf.db.GetAllWhitelist() + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "获取白名单时发生错误。")) + return + } if len(whitelist) == 0 { bot.Send(tgbotapi.NewMessage(message.Chat.ID, "白名单为空。")) } else { - SendLongMessageWithoutNumbering(bot, message.Chat.ID, "白名单域名列表:", whitelist) + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "白名单域名列表:\n"+strings.Join(whitelist, "\n"))) } case "addwhite": if args != "" { domain := strings.ToLower(args) - if !lf.db.WhitelistExists(domain) { - lf.db.AddWhitelist(domain) + exists, err := lf.db.WhitelistExists(domain) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "检查白名单时发生错误。")) + return + } + if !exists { + err = lf.db.AddWhitelist(domain) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "添加到白名单时发生错误。")) + return + } bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已添加到白名单。", domain))) } else { bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已在白名单中。", domain))) @@ -214,8 +273,17 @@ func (lf *LinkFilter) HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbo case "delwhite": if args != "" { domain := strings.ToLower(args) - if lf.db.WhitelistExists(domain) { - lf.db.RemoveWhitelist(domain) + exists, err := lf.db.WhitelistExists(domain) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "检查白名单时发生错误。")) + return + } + if exists { + err = lf.db.RemoveWhitelist(domain) + if err != nil { + bot.Send(tgbotapi.NewMessage(message.Chat.ID, "从白名单删除时发生错误。")) + return + } bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已从白名单中删除。", domain))) } else { bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 不在白名单中。", domain)))