mirror of
https://github.com/woodchen-ink/Q58Bot.git
synced 2025-07-18 13:52:07 +08:00
数据库连接优化及数据加载改进
对Database类进行重构,使用上下文管理器处理数据库连接,以简化代码并提高可读性。在初始化时确保数据库文件目录存在,以防止文件路径错误。此外,改进了LinkFilter类,从数据库加载数据,而非在初始化时立即加载,以提高灵活性。 另外,迁移脚本中增加了日志记录,以提高操作的可见性,并处理潜在的错误,以增强脚本的健壮性。
This commit is contained in:
parent
8a2ade4ea4
commit
52f902354e
@ -1,54 +1,37 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self, db_file):
|
def __init__(self, db_file):
|
||||||
self.db_file = db_file
|
self.db_file = db_file
|
||||||
self.conn = None
|
os.makedirs(os.path.dirname(db_file), exist_ok=True)
|
||||||
self.create_tables()
|
self.create_tables()
|
||||||
|
|
||||||
def create_tables(self):
|
def create_tables(self):
|
||||||
try:
|
with sqlite3.connect(self.db_file) as conn:
|
||||||
self.conn = sqlite3.connect(self.db_file)
|
cursor = conn.cursor()
|
||||||
cursor = self.conn.cursor()
|
cursor.execute('''
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS keywords
|
CREATE TABLE IF NOT EXISTS keywords
|
||||||
(id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
|
(id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
|
||||||
"""
|
''')
|
||||||
)
|
cursor.execute('''
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS whitelist
|
CREATE TABLE IF NOT EXISTS whitelist
|
||||||
(id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
|
(id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
|
||||||
"""
|
''')
|
||||||
)
|
conn.commit()
|
||||||
self.conn.commit()
|
|
||||||
except sqlite3.Error as e:
|
|
||||||
logger.error(f"Database error: {e}")
|
|
||||||
finally:
|
|
||||||
if self.conn:
|
|
||||||
self.conn.close()
|
|
||||||
|
|
||||||
def execute_query(self, query, params=None):
|
def execute_query(self, query, params=None):
|
||||||
try:
|
with sqlite3.connect(self.db_file) as conn:
|
||||||
self.conn = sqlite3.connect(self.db_file)
|
cursor = conn.cursor()
|
||||||
cursor = self.conn.cursor()
|
|
||||||
if params:
|
if params:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
else:
|
else:
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
self.conn.commit()
|
conn.commit()
|
||||||
return cursor
|
return cursor.fetchall()
|
||||||
except sqlite3.Error as e:
|
|
||||||
logger.error(f"Query execution error: {e}")
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
if self.conn:
|
|
||||||
self.conn.close()
|
|
||||||
|
|
||||||
def add_keyword(self, keyword):
|
def add_keyword(self, keyword):
|
||||||
query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
|
query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
|
||||||
@ -60,13 +43,12 @@ class Database:
|
|||||||
|
|
||||||
def get_all_keywords(self):
|
def get_all_keywords(self):
|
||||||
query = "SELECT keyword FROM keywords"
|
query = "SELECT keyword FROM keywords"
|
||||||
cursor = self.execute_query(query)
|
results = self.execute_query(query)
|
||||||
return [row[0] for row in cursor] if cursor else []
|
return [row[0] for row in results]
|
||||||
|
|
||||||
def remove_keywords_containing(self, substring):
|
def remove_keywords_containing(self, substring):
|
||||||
query = "DELETE FROM keywords WHERE keyword LIKE ?"
|
query = "DELETE FROM keywords WHERE keyword LIKE ?"
|
||||||
cursor = self.execute_query(query, (f"%{substring}%",))
|
return self.execute_query(query, (f"%{substring}%",))
|
||||||
return cursor.rowcount if cursor else 0
|
|
||||||
|
|
||||||
def add_whitelist(self, domain):
|
def add_whitelist(self, domain):
|
||||||
query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
|
query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
|
||||||
@ -78,5 +60,5 @@ class Database:
|
|||||||
|
|
||||||
def get_all_whitelist(self):
|
def get_all_whitelist(self):
|
||||||
query = "SELECT domain FROM whitelist"
|
query = "SELECT domain FROM whitelist"
|
||||||
cursor = self.execute_query(query)
|
results = self.execute_query(query)
|
||||||
return [row[0] for row in cursor] if cursor else []
|
return [row[0] for row in results]
|
||||||
|
@ -9,8 +9,7 @@ logger = logging.getLogger("TeleGuard.LinkFilter")
|
|||||||
class LinkFilter:
|
class LinkFilter:
|
||||||
def __init__(self, db_file):
|
def __init__(self, db_file):
|
||||||
self.db = Database(db_file)
|
self.db = Database(db_file)
|
||||||
self.keywords = self.db.get_all_keywords()
|
self.load_data_from_file()
|
||||||
self.whitelist = self.db.get_all_whitelist()
|
|
||||||
|
|
||||||
self.link_pattern = re.compile(
|
self.link_pattern = re.compile(
|
||||||
r"""
|
r"""
|
||||||
@ -30,11 +29,12 @@ class LinkFilter:
|
|||||||
re.VERBOSE | re.IGNORECASE,
|
re.VERBOSE | re.IGNORECASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_data_from_file(self):
|
def load_data_from_file(self):
|
||||||
self.keywords = self.db.get_all_keywords()
|
self.keywords = self.db.get_all_keywords()
|
||||||
self.whitelist = self.db.get_all_whitelist()
|
self.whitelist = self.db.get_all_whitelist()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries"
|
f"Loaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries from database"
|
||||||
)
|
)
|
||||||
|
|
||||||
def normalize_link(self, link):
|
def normalize_link(self, link):
|
||||||
|
@ -1,29 +1,46 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
from database import Database
|
from database import Database
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def migrate_data(json_file, db_file):
|
def migrate_data(json_file, db_file):
|
||||||
|
try:
|
||||||
# 确保 data 目录存在
|
# 确保 data 目录存在
|
||||||
os.makedirs(os.path.dirname(db_file), exist_ok=True)
|
os.makedirs(os.path.dirname(db_file), exist_ok=True)
|
||||||
|
logger.info(f"Ensuring directory exists: {os.path.dirname(db_file)}")
|
||||||
|
|
||||||
# 创建数据库连接
|
# 创建数据库连接
|
||||||
db = Database(db_file)
|
db = Database(db_file)
|
||||||
|
logger.info(f"Database connection created: {db_file}")
|
||||||
|
|
||||||
# 读取 JSON 文件
|
# 读取 JSON 文件
|
||||||
with open(json_file, 'r') as f:
|
with open(json_file, 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
logger.info(f"JSON file loaded: {json_file}")
|
||||||
|
|
||||||
# 迁移关键词
|
# 迁移关键词
|
||||||
for keyword in data.get('keywords', []):
|
keywords = data.get('keywords', [])
|
||||||
|
for keyword in keywords:
|
||||||
db.add_keyword(keyword)
|
db.add_keyword(keyword)
|
||||||
|
logger.info(f"Migrated {len(keywords)} keywords")
|
||||||
|
|
||||||
# 迁移白名单
|
# 迁移白名单
|
||||||
for domain in data.get('whitelist', []):
|
whitelist = data.get('whitelist', [])
|
||||||
|
for domain in whitelist:
|
||||||
db.add_whitelist(domain)
|
db.add_whitelist(domain)
|
||||||
|
logger.info(f"Migrated {len(whitelist)} whitelist entries")
|
||||||
|
|
||||||
print(f"迁移完成。关键词:{len(data.get('keywords', []))}个,白名单:{len(data.get('whitelist', []))}个")
|
logger.info(f"Migration complete. Keywords: {len(keywords)}, Whitelist: {len(whitelist)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"An error occurred during migration: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
json_file = 'keywords.json' # 旧的 JSON 文件路径
|
json_file = '/app/data/keywords.json' # 旧的 JSON 文件路径
|
||||||
db_file = os.path.join('data', 'q58.db') # 新的数据库文件路径
|
db_file = '/app/data/q58.db' # 新的数据库文件路径
|
||||||
migrate_data(json_file, db_file)
|
migrate_data(json_file, db_file)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user