数据库连接优化及数据加载改进

对Database类进行重构,使用上下文管理器处理数据库连接,以简化代码并提高可读性。在初始化时确保数据库文件目录存在,以防止文件路径错误。此外,改进了LinkFilter类,将数据加载逻辑集中到可复用的 load_data_from_file 方法中(初始化时调用,之后也可随时重新加载),以提高灵活性。

另外,迁移脚本中增加了日志记录,以提高操作的可见性,并处理潜在的错误,以增强脚本的健壮性。
This commit is contained in:
wood 2024-09-09 20:07:15 +08:00
parent 8a2ade4ea4
commit 52f902354e
3 changed files with 54 additions and 55 deletions

View File

@ -1,54 +1,37 @@
import sqlite3 import sqlite3
import logging import logging
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Database: class Database:
def __init__(self, db_file): def __init__(self, db_file):
self.db_file = db_file self.db_file = db_file
self.conn = None os.makedirs(os.path.dirname(db_file), exist_ok=True)
self.create_tables() self.create_tables()
def create_tables(self):
    """Create the keywords and whitelist tables if they do not already exist."""
    # NOTE: "with sqlite3.connect(...)" only scopes the transaction
    # (commit/rollback); it does NOT close the connection -- close explicitly
    # so repeated calls don't leak file handles.
    conn = sqlite3.connect(self.db_file)
    try:
        with conn:  # commits on success, rolls back on error
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS keywords
                (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS whitelist
                (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
            ''')
    finally:
        conn.close()
def execute_query(self, query, params=None):
    """Run a single SQL statement and return all result rows.

    Args:
        query: SQL text, optionally containing "?" placeholders.
        params: Sequence bound to the placeholders, or None for none.

    Returns:
        List of result rows (empty for statements that produce no rows).

    Raises:
        sqlite3.Error: propagated from the driver (the transaction is
            rolled back by the connection context manager).
    """
    conn = sqlite3.connect(self.db_file)
    try:
        with conn:  # transaction scope: commit on success, rollback on error
            cursor = conn.cursor()
            # "is not None" so an empty-but-present parameter tuple is honoured.
            if params is not None:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            return cursor.fetchall()
    finally:
        # The connection context manager does not close; do it ourselves.
        conn.close()
def add_keyword(self, keyword): def add_keyword(self, keyword):
query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)" query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
@ -60,13 +43,12 @@ class Database:
def get_all_keywords(self):
    """Return every stored keyword as a flat list of strings."""
    fetched = self.execute_query("SELECT keyword FROM keywords")
    return [record[0] for record in fetched]
def remove_keywords_containing(self, substring):
    """Delete every keyword that contains *substring* (SQL LIKE match).

    Returns whatever execute_query yields for the DELETE statement.
    """
    pattern = f"%{substring}%"
    return self.execute_query("DELETE FROM keywords WHERE keyword LIKE ?", (pattern,))
def add_whitelist(self, domain): def add_whitelist(self, domain):
query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)" query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
@ -78,5 +60,5 @@ class Database:
def get_all_whitelist(self):
    """Return every whitelisted domain as a flat list of strings."""
    fetched = self.execute_query("SELECT domain FROM whitelist")
    return [record[0] for record in fetched]

View File

@ -9,8 +9,7 @@ logger = logging.getLogger("TeleGuard.LinkFilter")
class LinkFilter: class LinkFilter:
def __init__(self, db_file): def __init__(self, db_file):
self.db = Database(db_file) self.db = Database(db_file)
self.keywords = self.db.get_all_keywords() self.load_data_from_file()
self.whitelist = self.db.get_all_whitelist()
self.link_pattern = re.compile( self.link_pattern = re.compile(
r""" r"""
@ -30,11 +29,12 @@ class LinkFilter:
re.VERBOSE | re.IGNORECASE, re.VERBOSE | re.IGNORECASE,
) )
def load_data_from_file(self):
    """(Re)load keywords and whitelist from the database into memory.

    NOTE(review): despite the name, this reads from the database (via
    self.db), not from a file -- the name is kept because callers use it.
    """
    self.keywords = self.db.get_all_keywords()
    self.whitelist = self.db.get_all_whitelist()
    # Lazy %-style args: the message is only rendered when INFO is enabled.
    logger.info(
        "Loaded %d keywords and %d whitelist entries from database",
        len(self.keywords),
        len(self.whitelist),
    )
def normalize_link(self, link): def normalize_link(self, link):

View File

@ -1,29 +1,46 @@
import json import json
import os import os
import logging
from database import Database from database import Database
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def migrate_data(json_file, db_file):
    """Migrate keywords and whitelist entries from a legacy JSON file into SQLite.

    Args:
        json_file: Path to the legacy JSON file; may contain "keywords" and
            "whitelist" arrays (both optional).
        db_file: Destination SQLite database path.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    try:
        # Make sure the data directory exists.  os.path.dirname() returns ""
        # for a bare filename and makedirs("") raises, so guard against that.
        parent = os.path.dirname(db_file)
        if parent:
            os.makedirs(parent, exist_ok=True)
        logger.info(f"Ensuring directory exists: {os.path.dirname(db_file)}")

        # Create the database connection (Database also creates the schema).
        db = Database(db_file)
        logger.info(f"Database connection created: {db_file}")

        # Read the legacy JSON file; explicit encoding for portability.
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        logger.info(f"JSON file loaded: {json_file}")

        # Migrate keywords.
        keywords = data.get('keywords', [])
        for keyword in keywords:
            db.add_keyword(keyword)
        logger.info(f"Migrated {len(keywords)} keywords")

        # Migrate whitelist.
        whitelist = data.get('whitelist', [])
        for domain in whitelist:
            db.add_whitelist(domain)
        logger.info(f"Migrated {len(whitelist)} whitelist entries")

        logger.info(f"Migration complete. Keywords: {len(keywords)}, Whitelist: {len(whitelist)}")
    except Exception as e:
        # logger.exception records the traceback; re-raise so the caller
        # (and the process exit code) still sees the failure.
        logger.exception(f"An error occurred during migration: {str(e)}")
        raise
if __name__ == "__main__":
    json_file = '/app/data/keywords.json'  # path of the legacy JSON file
    db_file = '/app/data/q58.db'  # path of the new SQLite database file
    migrate_data(json_file, db_file)