From 52f902354e5d43c17f8c68132a48bf73ebc74c9e Mon Sep 17 00:00:00 2001
From: wood
Date: Mon, 9 Sep 2024 20:07:15 +0800
Subject: [PATCH] Optimize database connections and improve data loading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Refactor the Database class to handle database connections through
context managers, which simplifies the code and improves readability.
Ensure the directory of the database file exists at initialization time
to prevent file path errors. In addition, the LinkFilter class now loads
its keywords and whitelist from the database via a dedicated
load_data_from_file() method instead of inline in the constructor,
which makes reloading more flexible.

The migration script also gains log output for better visibility into
each step and handles potential errors to make it more robust.
---
 src/database.py    | 54 ++++++++++++++++------------------------------
 src/link_filter.py |  6 +++---
 src/migrate.py     | 49 +++++++++++++++++++++++++++--------------
 3 files changed, 54 insertions(+), 55 deletions(-)

diff --git a/src/database.py b/src/database.py
index 6793c25..368cf17 100644
--- a/src/database.py
+++ b/src/database.py
@@ -1,54 +1,37 @@
 import sqlite3
 import logging
+import os
 
 logger = logging.getLogger(__name__)
 
-
 class Database:
     def __init__(self, db_file):
         self.db_file = db_file
-        self.conn = None
+        os.makedirs(os.path.dirname(db_file), exist_ok=True)
         self.create_tables()
 
     def create_tables(self):
-        try:
-            self.conn = sqlite3.connect(self.db_file)
-            cursor = self.conn.cursor()
-            cursor.execute(
-                """
+        with sqlite3.connect(self.db_file) as conn:
+            cursor = conn.cursor()
+            cursor.execute('''
                 CREATE TABLE IF NOT EXISTS keywords
                 (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
-            """
-            )
-            cursor.execute(
-                """
+            ''')
+            cursor.execute('''
                 CREATE TABLE IF NOT EXISTS whitelist
                 (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
-            """
-            )
-            self.conn.commit()
-        except sqlite3.Error as e:
-            logger.error(f"Database error: {e}")
-        finally:
-            if self.conn:
-                self.conn.close()
+            ''')
+            conn.commit()
 
     def execute_query(self, query, params=None):
-        try:
-            self.conn = sqlite3.connect(self.db_file)
-            cursor = self.conn.cursor()
+        with sqlite3.connect(self.db_file) as conn:
+            cursor = conn.cursor()
             if params:
                 cursor.execute(query, params)
             else:
                 cursor.execute(query)
-            self.conn.commit()
-            return cursor
-        except sqlite3.Error as e:
-            logger.error(f"Query execution error: {e}")
-            return None
-        finally:
-            if self.conn:
-                self.conn.close()
+            conn.commit()
+            return cursor.fetchall()
 
     def add_keyword(self, keyword):
         query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
@@ -60,13 +43,12 @@ class Database:
 
     def get_all_keywords(self):
         query = "SELECT keyword FROM keywords"
-        cursor = self.execute_query(query)
-        return [row[0] for row in cursor] if cursor else []
+        results = self.execute_query(query)
+        return [row[0] for row in results]
 
     def remove_keywords_containing(self, substring):
         query = "DELETE FROM keywords WHERE keyword LIKE ?"
-        cursor = self.execute_query(query, (f"%{substring}%",))
-        return cursor.rowcount if cursor else 0
+        return self.execute_query(query, (f"%{substring}%",))
 
     def add_whitelist(self, domain):
         query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
@@ -78,5 +60,5 @@ class Database:
 
     def get_all_whitelist(self):
         query = "SELECT domain FROM whitelist"
-        cursor = self.execute_query(query)
-        return [row[0] for row in cursor] if cursor else []
+        results = self.execute_query(query)
+        return [row[0] for row in results]
diff --git a/src/link_filter.py b/src/link_filter.py
index f65f372..d0bdcf2 100644
--- a/src/link_filter.py
+++ b/src/link_filter.py
@@ -9,8 +9,7 @@ logger = logging.getLogger("TeleGuard.LinkFilter")
 class LinkFilter:
     def __init__(self, db_file):
         self.db = Database(db_file)
-        self.keywords = self.db.get_all_keywords()
-        self.whitelist = self.db.get_all_whitelist()
+        self.load_data_from_file()
 
         self.link_pattern = re.compile(
             r"""
@@ -30,11 +29,12 @@ class LinkFilter:
             re.VERBOSE | re.IGNORECASE,
         )
 
+    def load_data_from_file(self):
         self.keywords = self.db.get_all_keywords()
         self.whitelist = self.db.get_all_whitelist()
         logger.info(
-            f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries"
+            f"Loaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries from database"
         )
 
     def normalize_link(self, link):
diff --git a/src/migrate.py b/src/migrate.py
index 9169dc4..2e7ef48 100644
--- a/src/migrate.py
+++ b/src/migrate.py
@@ -1,29 +1,46 @@
 import json
 import os
+import logging
 from database import Database
 
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
 def migrate_data(json_file, db_file):
-    # Make sure the data directory exists
-    os.makedirs(os.path.dirname(db_file), exist_ok=True)
+    try:
+        # Make sure the data directory exists
+        os.makedirs(os.path.dirname(db_file), exist_ok=True)
+        logger.info(f"Ensuring directory exists: {os.path.dirname(db_file)}")
 
-    # Create the database connection
-    db = Database(db_file)
+        # Create the database connection
+        db = Database(db_file)
+        logger.info(f"Database connection created: {db_file}")
 
-    # Read the JSON file
-    with open(json_file, 'r') as f:
-        data = json.load(f)
+        # Read the JSON file
+        with open(json_file, 'r') as f:
+            data = json.load(f)
+        logger.info(f"JSON file loaded: {json_file}")
 
-    # Migrate keywords
-    for keyword in data.get('keywords', []):
-        db.add_keyword(keyword)
+        # Migrate keywords
+        keywords = data.get('keywords', [])
+        for keyword in keywords:
+            db.add_keyword(keyword)
+        logger.info(f"Migrated {len(keywords)} keywords")
 
-    # Migrate whitelist entries
-    for domain in data.get('whitelist', []):
-        db.add_whitelist(domain)
+        # Migrate whitelist entries
+        whitelist = data.get('whitelist', [])
+        for domain in whitelist:
+            db.add_whitelist(domain)
+        logger.info(f"Migrated {len(whitelist)} whitelist entries")
 
-    print(f"Migration complete. Keywords: {len(data.get('keywords', []))}, whitelist: {len(data.get('whitelist', []))}")
+        logger.info(f"Migration complete. Keywords: {len(keywords)}, Whitelist: {len(whitelist)}")
+
+    except Exception as e:
+        logger.error(f"An error occurred during migration: {str(e)}")
+        raise
 
 if __name__ == "__main__":
-    json_file = 'keywords.json'  # path to the old JSON file
-    db_file = os.path.join('data', 'q58.db')  # path to the new database file
+    json_file = '/app/data/keywords.json'  # path to the old JSON file
+    db_file = '/app/data/q58.db'  # path to the new database file
    migrate_data(json_file, db_file)
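
A quick end-to-end check of the patched code can look like the sketch below. It is illustrative only: running it from the repository root with src/ prepended to sys.path is an assumption, and the temporary paths and sample keyword/whitelist values are made up for the example; only the top-level "keywords"/"whitelist" JSON layout comes from migrate.py.

# Illustrative smoke test for the patched Database class and migration script.
# Assumes it is run from the repository root; the temp paths and sample
# keyword/whitelist values are hypothetical, not part of the patch.
import json
import os
import sys
import tempfile

sys.path.insert(0, "src")  # make database.py and migrate.py importable

from database import Database
from migrate import migrate_data

with tempfile.TemporaryDirectory() as tmp:
    json_file = os.path.join(tmp, "keywords.json")
    db_file = os.path.join(tmp, "data", "q58.db")

    # Old-style JSON layout that migrate_data() reads: top-level
    # "keywords" and "whitelist" arrays.
    with open(json_file, "w") as f:
        json.dump({"keywords": ["spam", "casino"],
                   "whitelist": ["example.com"]}, f)

    # Database.__init__ creates the data/ directory and the tables,
    # and each method call opens its own short-lived connection.
    migrate_data(json_file, db_file)

    db = Database(db_file)
    print(db.get_all_keywords())   # expected: ['spam', 'casino']
    print(db.get_all_whitelist())  # expected: ['example.com']

One thing to keep in mind about the refactor itself: sqlite3's connection context manager commits or rolls back the surrounding transaction, but it does not close the connection object, so the per-call connections opened in create_tables() and execute_query() are only released when they are garbage collected.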