From 71599408d80b8fe57c281d92a20c4ae9e52083f5 Mon Sep 17 00:00:00 2001 From: wood Date: Mon, 9 Sep 2024 20:26:41 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84(=E6=95=B0=E6=8D=AE=E5=BA=93)?= =?UTF-8?q?:=20=E5=9C=A8=E6=95=B0=E6=8D=AE=E5=BA=93=E4=B8=AD=E5=BC=95?= =?UTF-8?q?=E5=85=A5=E7=BC=93=E5=AD=98=E6=9C=BA=E5=88=B6=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E6=9F=A5=E8=AF=A2=E6=80=A7=E8=83=BD=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=85=A8=E6=96=87=E6=90=9C=E7=B4=A2=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/database.py | 61 ++++++++++++++++++++++++++++++++++------------ src/link_filter.py | 39 ++++++++++++++--------------- 2 files changed, 66 insertions(+), 34 deletions(-) diff --git a/src/database.py b/src/database.py index 368cf17..90e1aca 100644 --- a/src/database.py +++ b/src/database.py @@ -1,6 +1,7 @@ import sqlite3 import logging import os +import time logger = logging.getLogger(__name__) @@ -8,19 +9,32 @@ class Database: def __init__(self, db_file): self.db_file = db_file os.makedirs(os.path.dirname(db_file), exist_ok=True) + self._keywords_cache = None + self._whitelist_cache = None + self._cache_time = 0 self.create_tables() def create_tables(self): with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() + # 创建关键词表并添加索引 cursor.execute(''' CREATE TABLE IF NOT EXISTS keywords (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE) ''') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_keyword ON keywords(keyword)') + + # 创建白名单表并添加索引 cursor.execute(''' CREATE TABLE IF NOT EXISTS whitelist (id INTEGER PRIMARY KEY, domain TEXT UNIQUE) ''') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_domain ON whitelist(domain)') + + # 创建全文搜索虚拟表 + cursor.execute(''' + CREATE VIRTUAL TABLE IF NOT EXISTS keywords_fts USING fts5(keyword) + ''') conn.commit() def execute_query(self, query, params=None): @@ -34,31 +48,48 @@ class Database: return cursor.fetchall() def add_keyword(self, keyword): - query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)" - self.execute_query(query, (keyword,)) + self.execute_query("INSERT OR IGNORE INTO keywords (keyword) VALUES (?)", (keyword,)) + self.execute_query("INSERT OR IGNORE INTO keywords_fts (keyword) VALUES (?)", (keyword,)) + self._invalidate_cache() def remove_keyword(self, keyword): - query = "DELETE FROM keywords WHERE keyword = ?" - self.execute_query(query, (keyword,)) + self.execute_query("DELETE FROM keywords WHERE keyword = ?", (keyword,)) + self.execute_query("DELETE FROM keywords_fts WHERE keyword = ?", (keyword,)) + self._invalidate_cache() def get_all_keywords(self): - query = "SELECT keyword FROM keywords" - results = self.execute_query(query) - return [row[0] for row in results] + current_time = time.time() + if self._keywords_cache is None or current_time - self._cache_time > 300: # 5分钟缓存 + self._keywords_cache = [row[0] for row in self.execute_query("SELECT keyword FROM keywords")] + self._cache_time = current_time + return self._keywords_cache def remove_keywords_containing(self, substring): query = "DELETE FROM keywords WHERE keyword LIKE ?" - return self.execute_query(query, (f"%{substring}%",)) + result = self.execute_query(query, (f"%{substring}%",)) + self.execute_query("DELETE FROM keywords_fts WHERE keyword LIKE ?", (f"%{substring}%",)) + self._invalidate_cache() + return result def add_whitelist(self, domain): - query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)" - self.execute_query(query, (domain,)) + self.execute_query("INSERT OR IGNORE INTO whitelist (domain) VALUES (?)", (domain,)) + self._invalidate_cache() def remove_whitelist(self, domain): - query = "DELETE FROM whitelist WHERE domain = ?" - self.execute_query(query, (domain,)) + self.execute_query("DELETE FROM whitelist WHERE domain = ?", (domain,)) + self._invalidate_cache() def get_all_whitelist(self): - query = "SELECT domain FROM whitelist" - results = self.execute_query(query) - return [row[0] for row in results] + current_time = time.time() + if self._whitelist_cache is None or current_time - self._cache_time > 300: # 5分钟缓存 + self._whitelist_cache = [row[0] for row in self.execute_query("SELECT domain FROM whitelist")] + self._cache_time = current_time + return self._whitelist_cache + + def search_keywords(self, pattern): + return [row[0] for row in self.execute_query("SELECT keyword FROM keywords_fts WHERE keyword MATCH ?", (pattern,))] + + def _invalidate_cache(self): + self._keywords_cache = None + self._whitelist_cache = None + self._cache_time = 0 diff --git a/src/link_filter.py b/src/link_filter.py index d0bdcf2..fbcb89f 100644 --- a/src/link_filter.py +++ b/src/link_filter.py @@ -29,7 +29,6 @@ class LinkFilter: re.VERBOSE | re.IGNORECASE, ) - def load_data_from_file(self): self.keywords = self.db.get_all_keywords() self.whitelist = self.db.get_all_whitelist() @@ -56,9 +55,10 @@ class LinkFilter: if len(parts) > 2: domain = '.'.join(parts[-2:]) return domain.lower() + def is_whitelisted(self, link): domain = self.extract_domain(link) - result = domain in self.whitelist + result = domain in self.db.get_all_whitelist() # 使用缓存机制 logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}") return result @@ -66,7 +66,7 @@ class LinkFilter: if self.link_pattern.match(keyword): keyword = self.normalize_link(keyword) keyword = keyword.lstrip("/") - if keyword not in self.keywords: + if keyword not in self.db.get_all_keywords(): # 使用缓存机制 self.db.add_keyword(keyword) logger.info(f"New keyword added: {keyword}") self.load_data_from_file() @@ -74,22 +74,20 @@ class LinkFilter: logger.debug(f"Keyword already exists: {keyword}") def remove_keyword(self, keyword): - if keyword in self.keywords: + if keyword in self.db.get_all_keywords(): # 使用缓存机制 self.db.remove_keyword(keyword) self.load_data_from_file() return True return False def remove_keywords_containing(self, substring): - removed_keywords = [kw for kw in self.keywords if substring.lower() in kw.lower()] - for keyword in removed_keywords: - self.db.remove_keyword(keyword) + removed_keywords = self.db.remove_keywords_containing(substring) self.load_data_from_file() return removed_keywords def should_filter(self, text): logger.debug(f"Checking text: {text}") - if any(keyword.lower() in text.lower() for keyword in self.keywords): + if any(keyword.lower() in text.lower() for keyword in self.db.get_all_keywords()): # 使用缓存机制 logger.info(f"Text contains keyword: {text}") return True, [] @@ -101,7 +99,7 @@ class LinkFilter: normalized_link = normalized_link.lstrip("/") if not self.is_whitelisted(normalized_link): logger.debug(f"Link not whitelisted: {normalized_link}") - if normalized_link not in self.keywords: + if normalized_link not in self.db.get_all_keywords(): # 使用缓存机制 new_non_whitelisted_links.append(normalized_link) self.add_keyword(normalized_link) else: @@ -114,15 +112,14 @@ class LinkFilter: async def handle_keyword_command(self, event, command, args): if command == "/list": - self.load_data_from_file() - keywords = self.keywords + keywords = self.db.get_all_keywords() # 使用缓存机制 if not keywords: await event.reply("关键词列表为空。") else: await send_long_message(event, "当前关键词列表:", keywords) elif command == "/add" and args: keyword = " ".join(args) - if keyword not in self.keywords: + if keyword not in self.db.get_all_keywords(): # 使用缓存机制 self.add_keyword(keyword) await event.reply(f"关键词 '{keyword}' 已添加。") else: @@ -132,7 +129,7 @@ class LinkFilter: if self.remove_keyword(keyword): await event.reply(f"关键词 '{keyword}' 已删除。") else: - similar_keywords = [k for k in self.keywords if keyword.lower() in k.lower()] + similar_keywords = self.db.search_keywords(keyword) # 使用模糊搜索 if similar_keywords: await send_long_message( event, @@ -150,13 +147,19 @@ class LinkFilter: ) else: await event.reply(f"没有找到包含 '{substring}' 的关键词。") + elif command == "/search" and args: + pattern = " ".join(args) + search_results = self.db.search_keywords(pattern) + if search_results: + await send_long_message(event, f"搜索 '{pattern}' 的结果:", search_results) + else: + await event.reply(f"没有找到匹配 '{pattern}' 的关键词。") else: await event.reply("无效的命令或参数。") async def handle_whitelist_command(self, event, command, args): if command == "/listwhite": - self.load_data_from_file() - whitelist = self.whitelist + whitelist = self.db.get_all_whitelist() # 使用缓存机制 await event.reply( "白名单域名列表:\n" + "\n".join(whitelist) if whitelist @@ -164,17 +167,15 @@ class LinkFilter: ) elif command == "/addwhite" and args: domain = args[0].lower() - if domain not in self.whitelist: + if domain not in self.db.get_all_whitelist(): # 使用缓存机制 self.db.add_whitelist(domain) - self.load_data_from_file() await event.reply(f"域名 '{domain}' 已添加到白名单。") else: await event.reply(f"域名 '{domain}' 已在白名单中。") elif command == "/delwhite" and args: domain = args[0].lower() - if domain in self.whitelist: + if domain in self.db.get_all_whitelist(): # 使用缓存机制 self.db.remove_whitelist(domain) - self.load_data_from_file() await event.reply(f"域名 '{domain}' 已从白名单中删除。") else: await event.reply(f"域名 '{domain}' 不在白名单中。")