重构(数据库): 在数据库中引入缓存机制,优化查询性能,增加全文搜索功能。

This commit is contained in:
wood 2024-09-09 20:26:41 +08:00
parent f56395704e
commit 71599408d8
2 changed files with 66 additions and 34 deletions

View File

@ -1,6 +1,7 @@
import sqlite3 import sqlite3
import logging import logging
import os import os
import time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -8,19 +9,32 @@ class Database:
def __init__(self, db_file): def __init__(self, db_file):
self.db_file = db_file self.db_file = db_file
os.makedirs(os.path.dirname(db_file), exist_ok=True) os.makedirs(os.path.dirname(db_file), exist_ok=True)
self._keywords_cache = None
self._whitelist_cache = None
self._cache_time = 0
self.create_tables() self.create_tables()
def create_tables(self): def create_tables(self):
with sqlite3.connect(self.db_file) as conn: with sqlite3.connect(self.db_file) as conn:
cursor = conn.cursor() cursor = conn.cursor()
# 创建关键词表并添加索引
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS keywords CREATE TABLE IF NOT EXISTS keywords
(id INTEGER PRIMARY KEY, keyword TEXT UNIQUE) (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
''') ''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_keyword ON keywords(keyword)')
# 创建白名单表并添加索引
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS whitelist CREATE TABLE IF NOT EXISTS whitelist
(id INTEGER PRIMARY KEY, domain TEXT UNIQUE) (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
''') ''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_domain ON whitelist(domain)')
# 创建全文搜索虚拟表
cursor.execute('''
CREATE VIRTUAL TABLE IF NOT EXISTS keywords_fts USING fts5(keyword)
''')
conn.commit() conn.commit()
def execute_query(self, query, params=None): def execute_query(self, query, params=None):
@ -34,31 +48,48 @@ class Database:
return cursor.fetchall() return cursor.fetchall()
def add_keyword(self, keyword): def add_keyword(self, keyword):
query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)" self.execute_query("INSERT OR IGNORE INTO keywords (keyword) VALUES (?)", (keyword,))
self.execute_query(query, (keyword,)) self.execute_query("INSERT OR IGNORE INTO keywords_fts (keyword) VALUES (?)", (keyword,))
self._invalidate_cache()
def remove_keyword(self, keyword): def remove_keyword(self, keyword):
query = "DELETE FROM keywords WHERE keyword = ?" self.execute_query("DELETE FROM keywords WHERE keyword = ?", (keyword,))
self.execute_query(query, (keyword,)) self.execute_query("DELETE FROM keywords_fts WHERE keyword = ?", (keyword,))
self._invalidate_cache()
def get_all_keywords(self): def get_all_keywords(self):
query = "SELECT keyword FROM keywords" current_time = time.time()
results = self.execute_query(query) if self._keywords_cache is None or current_time - self._cache_time > 300: # 5分钟缓存
return [row[0] for row in results] self._keywords_cache = [row[0] for row in self.execute_query("SELECT keyword FROM keywords")]
self._cache_time = current_time
return self._keywords_cache
def remove_keywords_containing(self, substring): def remove_keywords_containing(self, substring):
query = "DELETE FROM keywords WHERE keyword LIKE ?" query = "DELETE FROM keywords WHERE keyword LIKE ?"
return self.execute_query(query, (f"%{substring}%",)) result = self.execute_query(query, (f"%{substring}%",))
self.execute_query("DELETE FROM keywords_fts WHERE keyword LIKE ?", (f"%{substring}%",))
self._invalidate_cache()
return result
def add_whitelist(self, domain): def add_whitelist(self, domain):
query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)" self.execute_query("INSERT OR IGNORE INTO whitelist (domain) VALUES (?)", (domain,))
self.execute_query(query, (domain,)) self._invalidate_cache()
def remove_whitelist(self, domain): def remove_whitelist(self, domain):
query = "DELETE FROM whitelist WHERE domain = ?" self.execute_query("DELETE FROM whitelist WHERE domain = ?", (domain,))
self.execute_query(query, (domain,)) self._invalidate_cache()
def get_all_whitelist(self): def get_all_whitelist(self):
query = "SELECT domain FROM whitelist" current_time = time.time()
results = self.execute_query(query) if self._whitelist_cache is None or current_time - self._cache_time > 300: # 5分钟缓存
return [row[0] for row in results] self._whitelist_cache = [row[0] for row in self.execute_query("SELECT domain FROM whitelist")]
self._cache_time = current_time
return self._whitelist_cache
def search_keywords(self, pattern):
return [row[0] for row in self.execute_query("SELECT keyword FROM keywords_fts WHERE keyword MATCH ?", (pattern,))]
def _invalidate_cache(self):
self._keywords_cache = None
self._whitelist_cache = None
self._cache_time = 0

View File

@ -29,7 +29,6 @@ class LinkFilter:
re.VERBOSE | re.IGNORECASE, re.VERBOSE | re.IGNORECASE,
) )
def load_data_from_file(self): def load_data_from_file(self):
self.keywords = self.db.get_all_keywords() self.keywords = self.db.get_all_keywords()
self.whitelist = self.db.get_all_whitelist() self.whitelist = self.db.get_all_whitelist()
@ -56,9 +55,10 @@ class LinkFilter:
if len(parts) > 2: if len(parts) > 2:
domain = '.'.join(parts[-2:]) domain = '.'.join(parts[-2:])
return domain.lower() return domain.lower()
def is_whitelisted(self, link): def is_whitelisted(self, link):
domain = self.extract_domain(link) domain = self.extract_domain(link)
result = domain in self.whitelist result = domain in self.db.get_all_whitelist() # 使用缓存机制
logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}") logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}")
return result return result
@ -66,7 +66,7 @@ class LinkFilter:
if self.link_pattern.match(keyword): if self.link_pattern.match(keyword):
keyword = self.normalize_link(keyword) keyword = self.normalize_link(keyword)
keyword = keyword.lstrip("/") keyword = keyword.lstrip("/")
if keyword not in self.keywords: if keyword not in self.db.get_all_keywords(): # 使用缓存机制
self.db.add_keyword(keyword) self.db.add_keyword(keyword)
logger.info(f"New keyword added: {keyword}") logger.info(f"New keyword added: {keyword}")
self.load_data_from_file() self.load_data_from_file()
@ -74,22 +74,20 @@ class LinkFilter:
logger.debug(f"Keyword already exists: {keyword}") logger.debug(f"Keyword already exists: {keyword}")
def remove_keyword(self, keyword): def remove_keyword(self, keyword):
if keyword in self.keywords: if keyword in self.db.get_all_keywords(): # 使用缓存机制
self.db.remove_keyword(keyword) self.db.remove_keyword(keyword)
self.load_data_from_file() self.load_data_from_file()
return True return True
return False return False
def remove_keywords_containing(self, substring): def remove_keywords_containing(self, substring):
removed_keywords = [kw for kw in self.keywords if substring.lower() in kw.lower()] removed_keywords = self.db.remove_keywords_containing(substring)
for keyword in removed_keywords:
self.db.remove_keyword(keyword)
self.load_data_from_file() self.load_data_from_file()
return removed_keywords return removed_keywords
def should_filter(self, text): def should_filter(self, text):
logger.debug(f"Checking text: {text}") logger.debug(f"Checking text: {text}")
if any(keyword.lower() in text.lower() for keyword in self.keywords): if any(keyword.lower() in text.lower() for keyword in self.db.get_all_keywords()): # 使用缓存机制
logger.info(f"Text contains keyword: {text}") logger.info(f"Text contains keyword: {text}")
return True, [] return True, []
@ -101,7 +99,7 @@ class LinkFilter:
normalized_link = normalized_link.lstrip("/") normalized_link = normalized_link.lstrip("/")
if not self.is_whitelisted(normalized_link): if not self.is_whitelisted(normalized_link):
logger.debug(f"Link not whitelisted: {normalized_link}") logger.debug(f"Link not whitelisted: {normalized_link}")
if normalized_link not in self.keywords: if normalized_link not in self.db.get_all_keywords(): # 使用缓存机制
new_non_whitelisted_links.append(normalized_link) new_non_whitelisted_links.append(normalized_link)
self.add_keyword(normalized_link) self.add_keyword(normalized_link)
else: else:
@ -114,15 +112,14 @@ class LinkFilter:
async def handle_keyword_command(self, event, command, args): async def handle_keyword_command(self, event, command, args):
if command == "/list": if command == "/list":
self.load_data_from_file() keywords = self.db.get_all_keywords() # 使用缓存机制
keywords = self.keywords
if not keywords: if not keywords:
await event.reply("关键词列表为空。") await event.reply("关键词列表为空。")
else: else:
await send_long_message(event, "当前关键词列表:", keywords) await send_long_message(event, "当前关键词列表:", keywords)
elif command == "/add" and args: elif command == "/add" and args:
keyword = " ".join(args) keyword = " ".join(args)
if keyword not in self.keywords: if keyword not in self.db.get_all_keywords(): # 使用缓存机制
self.add_keyword(keyword) self.add_keyword(keyword)
await event.reply(f"关键词 '{keyword}' 已添加。") await event.reply(f"关键词 '{keyword}' 已添加。")
else: else:
@ -132,7 +129,7 @@ class LinkFilter:
if self.remove_keyword(keyword): if self.remove_keyword(keyword):
await event.reply(f"关键词 '{keyword}' 已删除。") await event.reply(f"关键词 '{keyword}' 已删除。")
else: else:
similar_keywords = [k for k in self.keywords if keyword.lower() in k.lower()] similar_keywords = self.db.search_keywords(keyword) # 使用模糊搜索
if similar_keywords: if similar_keywords:
await send_long_message( await send_long_message(
event, event,
@ -150,13 +147,19 @@ class LinkFilter:
) )
else: else:
await event.reply(f"没有找到包含 '{substring}' 的关键词。") await event.reply(f"没有找到包含 '{substring}' 的关键词。")
elif command == "/search" and args:
pattern = " ".join(args)
search_results = self.db.search_keywords(pattern)
if search_results:
await send_long_message(event, f"搜索 '{pattern}' 的结果:", search_results)
else:
await event.reply(f"没有找到匹配 '{pattern}' 的关键词。")
else: else:
await event.reply("无效的命令或参数。") await event.reply("无效的命令或参数。")
async def handle_whitelist_command(self, event, command, args): async def handle_whitelist_command(self, event, command, args):
if command == "/listwhite": if command == "/listwhite":
self.load_data_from_file() whitelist = self.db.get_all_whitelist() # 使用缓存机制
whitelist = self.whitelist
await event.reply( await event.reply(
"白名单域名列表:\n" + "\n".join(whitelist) "白名单域名列表:\n" + "\n".join(whitelist)
if whitelist if whitelist
@ -164,17 +167,15 @@ class LinkFilter:
) )
elif command == "/addwhite" and args: elif command == "/addwhite" and args:
domain = args[0].lower() domain = args[0].lower()
if domain not in self.whitelist: if domain not in self.db.get_all_whitelist(): # 使用缓存机制
self.db.add_whitelist(domain) self.db.add_whitelist(domain)
self.load_data_from_file()
await event.reply(f"域名 '{domain}' 已添加到白名单。") await event.reply(f"域名 '{domain}' 已添加到白名单。")
else: else:
await event.reply(f"域名 '{domain}' 已在白名单中。") await event.reply(f"域名 '{domain}' 已在白名单中。")
elif command == "/delwhite" and args: elif command == "/delwhite" and args:
domain = args[0].lower() domain = args[0].lower()
if domain in self.whitelist: if domain in self.db.get_all_whitelist(): # 使用缓存机制
self.db.remove_whitelist(domain) self.db.remove_whitelist(domain)
self.load_data_from_file()
await event.reply(f"域名 '{domain}' 已从白名单中删除。") await event.reply(f"域名 '{domain}' 已从白名单中删除。")
else: else:
await event.reply(f"域名 '{domain}' 不在白名单中。") await event.reply(f"域名 '{domain}' 不在白名单中。")