diff --git a/requirements.txt b/requirements.txt
index 0ff539c..924e2c4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,3 @@ ccxt
 pyTelegramBotAPI
 schedule
 pytz
-tldextract
diff --git a/src/database.py b/src/database.py
new file mode 100644
index 0000000..6793c25
--- /dev/null
+++ b/src/database.py
@@ -0,0 +1,87 @@
+import sqlite3
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Database:
+    def __init__(self, db_file):
+        self.db_file = db_file
+        self.conn = None
+        self.create_tables()
+
+    def create_tables(self):
+        try:
+            self.conn = sqlite3.connect(self.db_file)
+            cursor = self.conn.cursor()
+            cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS keywords
+                (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
+                """
+            )
+            cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS whitelist
+                (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
+                """
+            )
+            self.conn.commit()
+        except sqlite3.Error as e:
+            logger.error(f"Database error: {e}")
+        finally:
+            if self.conn:
+                self.conn.close()
+
+    def execute_query(self, query, params=None):
+        # Fetch results BEFORE closing the connection: a sqlite3 cursor
+        # cannot be read once its connection has been closed.
+        conn = None
+        try:
+            conn = sqlite3.connect(self.db_file)
+            cursor = conn.cursor()
+            if params:
+                cursor.execute(query, params)
+            else:
+                cursor.execute(query)
+            rows = cursor.fetchall()
+            rowcount = cursor.rowcount
+            conn.commit()
+            return rows, rowcount
+        except sqlite3.Error as e:
+            logger.error(f"Query execution error: {e}")
+            return None, 0
+        finally:
+            if conn:
+                conn.close()
+
+    def add_keyword(self, keyword):
+        query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
+        self.execute_query(query, (keyword,))
+
+    def remove_keyword(self, keyword):
+        query = "DELETE FROM keywords WHERE keyword = ?"
+        self.execute_query(query, (keyword,))
+
+    def get_all_keywords(self):
+        query = "SELECT keyword FROM keywords"
+        rows, _ = self.execute_query(query)
+        return [row[0] for row in rows] if rows else []
+
+    def remove_keywords_containing(self, substring):
+        query = "DELETE FROM keywords WHERE keyword LIKE ?"
+        _, rowcount = self.execute_query(query, (f"%{substring}%",))
+        return rowcount
+
+    def add_whitelist(self, domain):
+        query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
+        self.execute_query(query, (domain,))
+
+    def remove_whitelist(self, domain):
+        query = "DELETE FROM whitelist WHERE domain = ?"
+        self.execute_query(query, (domain,))
+
+    def get_all_whitelist(self):
+        query = "SELECT domain FROM whitelist"
+        rows, _ = self.execute_query(query)
+        return [row[0] for row in rows] if rows else []
diff --git a/src/guard.py b/src/guard.py
index c523eed..b2dc7ca 100644
--- a/src/guard.py
+++ b/src/guard.py
@@ -8,13 +8,10 @@ import time
 from link_filter import LinkFilter
 from bot_commands import register_commands
 
-
-
 # 环境变量
 BOT_TOKEN = os.environ.get('BOT_TOKEN')
 ADMIN_ID = int(os.environ.get('ADMIN_ID'))
-KEYWORDS_FILE = '/app/data/keywords.json'
-WHITELIST_FILE = '/app/data/whitelist.json'
+DB_FILE = '/app/data/q58.db'  # 新的数据库文件路径
 
 # 设置日志
 DEBUG_MODE = os.environ.get('DEBUG_MODE', 'False').lower() == 'true'
@@ -30,7 +27,7 @@ link_filter_logger.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)
 logging.getLogger('telethon').setLevel(logging.WARNING)
 
 # 创建 LinkFilter 实例
-link_filter = LinkFilter(KEYWORDS_FILE, WHITELIST_FILE)
+link_filter = LinkFilter(DB_FILE)
 
 class RateLimiter:
     def __init__(self, max_calls, period):
@@ -73,7 +70,6 @@ async def process_message(event, client):
         if new_links:
             logger.info(f"New non-whitelisted links found: {new_links}")
 
-
 async def message_handler(event, link_filter, rate_limiter):
     if not event.is_private or event.sender_id != ADMIN_ID:
         async with rate_limiter:
@@ -94,6 +90,7 @@ async def command_handler(event, link_filter):
     if event.raw_text.startswith(('/add', '/delete', '/deletecontaining','/list', '/addwhite', '/delwhite', '/listwhite')):
         link_filter.load_data_from_file()
 
+
 async def start_bot():
     async with TelegramClient('bot', api_id=6, api_hash='eb06d4abfb49dc3eeb1aeb98ae0f581e') as client:
         await client.start(bot_token=BOT_TOKEN)
diff --git a/src/link_filter.py b/src/link_filter.py
index f80d146..f65f372 100644
--- a/src/link_filter.py
+++ b/src/link_filter.py
@@ -1,20 +1,16 @@
 import re
-import json
-import tldextract
 import urllib.parse
 import logging
+from database import Database
 from functions import send_long_message
 
 logger = logging.getLogger("TeleGuard.LinkFilter")
 
-
 class LinkFilter:
-    def __init__(self, keywords_file, whitelist_file):
-        self.keywords_file = keywords_file
-        self.whitelist_file = whitelist_file
-        self.keywords = []
-        self.whitelist = []
-        self.load_data_from_file()
+    def __init__(self, db_file):
+        self.db = Database(db_file)
+        self.keywords = self.db.get_all_keywords()
+        self.whitelist = self.db.get_all_whitelist()
 
         self.link_pattern = re.compile(
             r"""
@@ -34,49 +30,34 @@ class LinkFilter:
             re.VERBOSE | re.IGNORECASE,
         )
 
-    def load_json(self, file_path):
-        try:
-            with open(file_path, "r") as f:
-                return json.load(f)
-        except FileNotFoundError:
-            return []
-
-    def save_json(self, file_path, data):
-        with open(file_path, "w") as f:
-            json.dump(data, f)
-
-    def save_keywords(self):
-        self.save_json(self.keywords_file, self.keywords)
-
-    def save_whitelist(self):
-        self.save_json(self.whitelist_file, self.whitelist)
-
     def load_data_from_file(self):
-        self.keywords = self.load_json(self.keywords_file)
-        self.whitelist = self.load_json(self.whitelist_file)
+        self.keywords = self.db.get_all_keywords()
+        self.whitelist = self.db.get_all_whitelist()
         logger.info(
             f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries"
         )
 
     def normalize_link(self, link):
-        # 移除协议部分(如 http:// 或 https://)
         link = re.sub(r"^https?://", "", link)
-
-        # 移除开头的双斜杠
         link = link.lstrip("/")
-
         parsed = urllib.parse.urlparse(f"http://{link}")
         normalized = urllib.parse.urlunparse(
             ("", parsed.netloc, parsed.path, parsed.params, parsed.query, "")
         )
         result = normalized.rstrip("/")
-
         logger.debug(f"Normalized link: {link} -> {result}")
         return result
 
+    def extract_domain(self, url):
+        parsed = urllib.parse.urlparse(url)
+        domain = parsed.netloc or parsed.path
+        domain = domain.split(':')[0]  # Remove port if present
+        parts = domain.split('.')
+        if len(parts) > 2:
+            domain = '.'.join(parts[-2:])
+        return domain.lower()  # NOTE(review): naive eTLD handling — 'a.co.uk' -> 'co.uk'; confirm acceptable vs. removed tldextract
     def is_whitelisted(self, link):
-        extracted = tldextract.extract(link)
-        domain = f"{extracted.domain}.{extracted.suffix}"
+        domain = self.extract_domain(link)
         result = domain in self.whitelist
         logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}")
         return result
@@ -84,44 +65,26 @@ class LinkFilter:
     def add_keyword(self, keyword):
         if self.link_pattern.match(keyword):
             keyword = self.normalize_link(keyword)
-            # 确保在这里去掉开头的双斜杠
             keyword = keyword.lstrip("/")
         if keyword not in self.keywords:
-            self.keywords.append(keyword)
-            self.save_keywords()
+            self.db.add_keyword(keyword)
             logger.info(f"New keyword added: {keyword}")
-            self.load_data_from_file()  # 重新加载文件
+            self.load_data_from_file()
         else:
             logger.debug(f"Keyword already exists: {keyword}")
 
     def remove_keyword(self, keyword):
         if keyword in self.keywords:
-            self.keywords.remove(keyword)
-            self.save_keywords()
-            self.load_data_from_file()  # 重新加载以确保数据同步
+            self.db.remove_keyword(keyword)
+            self.load_data_from_file()
             return True
         return False
 
     def remove_keywords_containing(self, substring):
-        # 记录原始关键词列表的长度
-        original_count = len(self.keywords)
-
-        # 创建一个列表,包含所有需要移除的关键词
-        removed_keywords = [
-            kw for kw in self.keywords if substring.lower() in kw.lower()
-        ]
-
-        # 修改关键词列表,仅保留不包含指定子字符串的关键词
-        self.keywords = [
-            kw for kw in self.keywords if substring.lower() not in kw.lower()
-        ]
-
-        # 如果有关键词被移除,则保存关键词列表并重新加载数据
-        if removed_keywords:
-            self.save_keywords()
-            self.load_data_from_file()
-
-        # 返回被移除的关键词列表
+        removed_keywords = [kw for kw in self.keywords if substring.lower() in kw.lower()]
+        for keyword in removed_keywords:
+            self.db.remove_keyword(keyword)
+        self.load_data_from_file()
         return removed_keywords
 
     def should_filter(self, text):
@@ -135,7 +98,7 @@ class LinkFilter:
         new_non_whitelisted_links = []
         for link in links:
             normalized_link = self.normalize_link(link)
-            normalized_link = normalized_link.lstrip("/")  # 去除开头的双斜杠
+            normalized_link = normalized_link.lstrip("/")
             if not self.is_whitelisted(normalized_link):
                 logger.debug(f"Link not whitelisted: {normalized_link}")
                 if normalized_link not in self.keywords:
@@ -169,9 +132,7 @@ class LinkFilter:
             if self.remove_keyword(keyword):
                 await event.reply(f"关键词 '{keyword}' 已删除。")
             else:
-                similar_keywords = [
-                    k for k in self.keywords if keyword.lower() in k.lower()
-                ]
+                similar_keywords = [k for k in self.keywords if keyword.lower() in k.lower()]
                 if similar_keywords:
                     await send_long_message(
                         event,
@@ -204,8 +165,7 @@ class LinkFilter:
         elif command == "/addwhite" and args:
             domain = args[0].lower()
             if domain not in self.whitelist:
-                self.whitelist.append(domain)
-                self.save_whitelist()
+                self.db.add_whitelist(domain)
                 self.load_data_from_file()
                 await event.reply(f"域名 '{domain}' 已添加到白名单。")
             else:
@@ -213,8 +173,7 @@ class LinkFilter:
         elif command == "/delwhite" and args:
             domain = args[0].lower()
             if domain in self.whitelist:
-                self.whitelist.remove(domain)
-                self.save_whitelist()
+                self.db.remove_whitelist(domain)
                 self.load_data_from_file()
                 await event.reply(f"域名 '{domain}' 已从白名单中删除。")
             else:
diff --git a/src/migrate.py b/src/migrate.py
new file mode 100644
index 0000000..9169dc4
--- /dev/null
+++ b/src/migrate.py
@@ -0,0 +1,29 @@
+import json
+import os
+from database import Database
+
+def migrate_data(json_file, db_file):
+    # 确保 data 目录存在
+    os.makedirs(os.path.dirname(db_file), exist_ok=True)
+
+    # 创建数据库连接
+    db = Database(db_file)
+
+    # 读取 JSON 文件
+    with open(json_file, 'r') as f:
+        data = json.load(f)
+
+    # 迁移关键词
+    for keyword in data.get('keywords', []):
+        db.add_keyword(keyword)
+
+    # 迁移白名单
+    for domain in data.get('whitelist', []):
+        db.add_whitelist(domain)
+
+    print(f"迁移完成。关键词:{len(data.get('keywords', []))}个,白名单:{len(data.get('whitelist', []))}个")
+
+if __name__ == "__main__":
+    json_file = 'keywords.json'  # 旧的 JSON 文件路径
+    db_file = os.path.join('data', 'q58.db')  # 新的数据库文件路径
+    migrate_data(json_file, db_file)