This commit is contained in:
wood chen 2024-09-18 01:34:22 +08:00
parent 8056410bf6
commit 68c2651908
2 changed files with 137 additions and 31 deletions

View File

@ -120,17 +120,37 @@ func (d *Database) GetAllKeywords() ([]string, error) {
return d.keywordsCache, nil return d.keywordsCache, nil
} }
func (d *Database) RemoveKeywordsContaining(substring string) error { func (d *Database) RemoveKeywordsContaining(substring string) ([]string, error) {
_, err := d.db.Exec("DELETE FROM keywords WHERE keyword LIKE ?", "%"+substring+"%") // 首先获取要删除的关键词列表
rows, err := d.db.Query("SELECT keyword FROM keywords WHERE keyword LIKE ?", "%"+substring+"%")
if err != nil { if err != nil {
return err return nil, err
} }
defer rows.Close()
var removedKeywords []string
for rows.Next() {
var keyword string
if err := rows.Scan(&keyword); err != nil {
return nil, err
}
removedKeywords = append(removedKeywords, keyword)
}
// 执行删除操作
_, err = d.db.Exec("DELETE FROM keywords WHERE keyword LIKE ?", "%"+substring+"%")
if err != nil {
return nil, err
}
// 从 FTS 表中也删除这些关键词
_, err = d.db.Exec("DELETE FROM keywords_fts WHERE keyword LIKE ?", "%"+substring+"%") _, err = d.db.Exec("DELETE FROM keywords_fts WHERE keyword LIKE ?", "%"+substring+"%")
if err != nil { if err != nil {
return err return nil, err
} }
d.invalidateCache() d.invalidateCache()
return nil return removedKeywords, nil
} }
func (d *Database) AddWhitelist(domain string) error { func (d *Database) AddWhitelist(domain string) error {
@ -171,6 +191,24 @@ func (d *Database) SearchKeywords(pattern string) ([]string, error) {
return d.executeQuery("SELECT keyword FROM keywords_fts WHERE keyword MATCH ?", pattern) return d.executeQuery("SELECT keyword FROM keywords_fts WHERE keyword MATCH ?", pattern)
} }
func (d *Database) KeywordExists(keyword string) (bool, error) {
var count int
err := d.db.QueryRow("SELECT COUNT(*) FROM keywords WHERE keyword = ?", keyword).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) WhitelistExists(domain string) (bool, error) {
var count int
err := d.db.QueryRow("SELECT COUNT(*) FROM whitelist WHERE domain = ?", domain).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) invalidateCache() { func (d *Database) invalidateCache() {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()

View File

@ -19,19 +19,34 @@ type LinkFilter struct {
linkPattern *regexp.Regexp linkPattern *regexp.Regexp
} }
func NewLinkFilter(dbFile string) *LinkFilter { func NewLinkFilter(dbFile string) (*LinkFilter, error) {
db, err := NewDatabase(dbFile)
if err != nil {
return nil, err
}
lf := &LinkFilter{ lf := &LinkFilter{
db: NewDatabase(dbFile), db: db,
} }
lf.linkPattern = regexp.MustCompile(`(?i)\b(?:(?:https?://)?(?:(?:www\.)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:t\.me|telegram\.me))(?:/[^\s]*)?)`) lf.linkPattern = regexp.MustCompile(`(?i)\b(?:(?:https?://)?(?:(?:www\.)?(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}|(?:t\.me|telegram\.me))(?:/[^\s]*)?)`)
lf.LoadDataFromFile() err = lf.LoadDataFromFile()
return lf if err != nil {
return nil, err
}
return lf, nil
} }
func (lf *LinkFilter) LoadDataFromFile() { func (lf *LinkFilter) LoadDataFromFile() error {
lf.keywords = lf.db.GetAllKeywords() var err error
lf.whitelist = lf.db.GetAllWhitelist() lf.keywords, err = lf.db.GetAllKeywords()
if err != nil {
return err
}
lf.whitelist, err = lf.db.GetAllWhitelist()
if err != nil {
return err
}
logger.Printf("Loaded %d keywords and %d whitelist entries from database", len(lf.keywords), len(lf.whitelist)) logger.Printf("Loaded %d keywords and %d whitelist entries from database", len(lf.keywords), len(lf.whitelist))
return nil
} }
func (lf *LinkFilter) NormalizeLink(link string) string { func (lf *LinkFilter) NormalizeLink(link string) string {
@ -74,7 +89,7 @@ func (lf *LinkFilter) IsWhitelisted(link string) bool {
return false return false
} }
func (lf *LinkFilter) AddKeyword(keyword string) { func (lf *LinkFilter) AddKeyword(keyword string) error {
if lf.linkPattern.MatchString(keyword) { if lf.linkPattern.MatchString(keyword) {
keyword = lf.NormalizeLink(keyword) keyword = lf.NormalizeLink(keyword)
} }
@ -82,12 +97,15 @@ func (lf *LinkFilter) AddKeyword(keyword string) {
for _, k := range lf.keywords { for _, k := range lf.keywords {
if k == keyword { if k == keyword {
logger.Printf("Keyword already exists: %s", keyword) logger.Printf("Keyword already exists: %s", keyword)
return return nil
} }
} }
lf.db.AddKeyword(keyword) err := lf.db.AddKeyword(keyword)
if err != nil {
return err
}
logger.Printf("New keyword added: %s", keyword) logger.Printf("New keyword added: %s", keyword)
lf.LoadDataFromFile() return lf.LoadDataFromFile()
} }
func (lf *LinkFilter) RemoveKeyword(keyword string) bool { func (lf *LinkFilter) RemoveKeyword(keyword string) bool {
@ -101,10 +119,16 @@ func (lf *LinkFilter) RemoveKeyword(keyword string) bool {
return false return false
} }
func (lf *LinkFilter) RemoveKeywordsContaining(substring string) []string { func (lf *LinkFilter) RemoveKeywordsContaining(substring string) ([]string, error) {
removed := lf.db.RemoveKeywordsContaining(substring) removed, err := lf.db.RemoveKeywordsContaining(substring)
lf.LoadDataFromFile() if err != nil {
return removed return nil, err
}
err = lf.LoadDataFromFile()
if err != nil {
return nil, err
}
return removed, nil
} }
func (lf *LinkFilter) ShouldFilter(text string) (bool, []string) { func (lf *LinkFilter) ShouldFilter(text string) (bool, []string) {
@ -147,7 +171,11 @@ func (lf *LinkFilter) ShouldFilter(text string) (bool, []string) {
func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) { func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) {
switch command { switch command {
case "list": case "list":
keywords := lf.db.GetAllKeywords() keywords, err := lf.db.GetAllKeywords()
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "获取关键词列表时发生错误。"))
return
}
if len(keywords) == 0 { if len(keywords) == 0 {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "关键词列表为空。")) bot.Send(tgbotapi.NewMessage(message.Chat.ID, "关键词列表为空。"))
} else { } else {
@ -156,20 +184,34 @@ func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbota
case "add": case "add":
if args != "" { if args != "" {
keyword := args keyword := args
if !lf.db.KeywordExists(keyword) { exists, err := lf.db.KeywordExists(keyword)
lf.AddKeyword(keyword) if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "检查关键词时发生错误。"))
return
}
if !exists {
err = lf.AddKeyword(keyword)
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "添加关键词时发生错误。"))
return
}
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已添加。", keyword))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已添加。", keyword)))
} else { } else {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已存在。", keyword))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已存在。", keyword)))
} }
} }
case "delete": case "delete":
if args != "" { if args != "" {
keyword := args keyword := args
if lf.RemoveKeyword(keyword) { if lf.RemoveKeyword(keyword) {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已删除。", keyword))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("关键词 '%s' 已删除。", keyword)))
} else { } else {
similarKeywords := lf.db.SearchKeywords(keyword) similarKeywords, err := lf.db.SearchKeywords(keyword)
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "搜索关键词时发生错误。"))
return
}
if len(similarKeywords) > 0 { if len(similarKeywords) > 0 {
SendLongMessage(bot, message.Chat.ID, fmt.Sprintf("未找到精确匹配的关键词 '%s'。\n\n以下是相似的关键词", keyword), similarKeywords) SendLongMessage(bot, message.Chat.ID, fmt.Sprintf("未找到精确匹配的关键词 '%s'。\n\n以下是相似的关键词", keyword), similarKeywords)
} else { } else {
@ -180,7 +222,11 @@ func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbota
case "deletecontaining": case "deletecontaining":
if args != "" { if args != "" {
substring := args substring := args
removedKeywords := lf.RemoveKeywordsContaining(substring) removedKeywords, err := lf.RemoveKeywordsContaining(substring)
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "删除关键词时发生错误。"))
return
}
if len(removedKeywords) > 0 { if len(removedKeywords) > 0 {
SendLongMessage(bot, message.Chat.ID, fmt.Sprintf("已删除包含 '%s' 的以下关键词:", substring), removedKeywords) SendLongMessage(bot, message.Chat.ID, fmt.Sprintf("已删除包含 '%s' 的以下关键词:", substring), removedKeywords)
} else { } else {
@ -195,17 +241,30 @@ func (lf *LinkFilter) HandleKeywordCommand(bot *tgbotapi.BotAPI, message *tgbota
func (lf *LinkFilter) HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) { func (lf *LinkFilter) HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbotapi.Message, command string, args string) {
switch command { switch command {
case "listwhite": case "listwhite":
whitelist := lf.db.GetAllWhitelist() whitelist, err := lf.db.GetAllWhitelist()
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "获取白名单时发生错误。"))
return
}
if len(whitelist) == 0 { if len(whitelist) == 0 {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "白名单为空。")) bot.Send(tgbotapi.NewMessage(message.Chat.ID, "白名单为空。"))
} else { } else {
SendLongMessageWithoutNumbering(bot, message.Chat.ID, "白名单域名列表:", whitelist) bot.Send(tgbotapi.NewMessage(message.Chat.ID, "白名单域名列表:\n"+strings.Join(whitelist, "\n")))
} }
case "addwhite": case "addwhite":
if args != "" { if args != "" {
domain := strings.ToLower(args) domain := strings.ToLower(args)
if !lf.db.WhitelistExists(domain) { exists, err := lf.db.WhitelistExists(domain)
lf.db.AddWhitelist(domain) if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "检查白名单时发生错误。"))
return
}
if !exists {
err = lf.db.AddWhitelist(domain)
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "添加到白名单时发生错误。"))
return
}
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已添加到白名单。", domain))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已添加到白名单。", domain)))
} else { } else {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已在白名单中。", domain))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已在白名单中。", domain)))
@ -214,8 +273,17 @@ func (lf *LinkFilter) HandleWhitelistCommand(bot *tgbotapi.BotAPI, message *tgbo
case "delwhite": case "delwhite":
if args != "" { if args != "" {
domain := strings.ToLower(args) domain := strings.ToLower(args)
if lf.db.WhitelistExists(domain) { exists, err := lf.db.WhitelistExists(domain)
lf.db.RemoveWhitelist(domain) if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "检查白名单时发生错误。"))
return
}
if exists {
err = lf.db.RemoveWhitelist(domain)
if err != nil {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, "从白名单删除时发生错误。"))
return
}
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已从白名单中删除。", domain))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 已从白名单中删除。", domain)))
} else { } else {
bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 不在白名单中。", domain))) bot.Send(tgbotapi.NewMessage(message.Chat.ID, fmt.Sprintf("域名 '%s' 不在白名单中。", domain)))