diff --git a/src/bot_commands.py b/src/bot_commands.py index d952c2f..60697cc 100644 --- a/src/bot_commands.py +++ b/src/bot_commands.py @@ -1,13 +1,16 @@ import os -import json from telethon.tl.types import InputPeerUser from telethon.tl.functions.bots import SetBotCommandsRequest from telethon.tl.types import BotCommand +from link_filter import LinkFilter KEYWORDS_FILE = '/app/data/keywords.json' WHITELIST_FILE = '/app/data/whitelist.json' ADMIN_ID = int(os.environ.get('ADMIN_ID')) +# 创建 LinkFilter 实例 +link_filter = LinkFilter(KEYWORDS_FILE, WHITELIST_FILE) + async def register_commands(client, admin_id): commands = [ BotCommand('add', '添加新的关键词'), @@ -27,17 +30,6 @@ async def register_commands(client, admin_id): print("Bot commands registered successfully.") except Exception as e: print(f"Failed to register bot commands: {str(e)}") - -def load_json(file_path): - try: - with open(file_path, 'r') as f: - return json.load(f) - except FileNotFoundError: - return [] - -def save_json(file_path, data): - with open(file_path, 'w') as f: - json.dump(data, f) async def handle_command(event, client): sender = await event.get_sender() @@ -53,23 +45,21 @@ async def handle_command(event, client): await handle_whitelist_command(event, command, args) async def handle_keyword_command(event, command, args): - keywords = load_json(KEYWORDS_FILE) - if command == '/list': + keywords = link_filter.keywords await event.reply("当前关键词列表:\n" + "\n".join(keywords) if keywords else "关键词列表为空。") elif command == '/add' and args: - keyword = args[0].lower() - if keyword not in keywords: - keywords.append(keyword) - save_json(KEYWORDS_FILE, keywords) + keyword = args[0] + normalized_keyword = link_filter.normalize_link(keyword) if link_filter.link_pattern.match(keyword) else keyword.lower() + if normalized_keyword not in link_filter.keywords: + link_filter.add_keyword(normalized_keyword) await event.reply(f"关键词 '{keyword}' 已添加。") else: await event.reply(f"关键词 '{keyword}' 已存在。") elif command == '/delete' and args: - keyword = args[0].lower() - if keyword in keywords: - keywords.remove(keyword) - save_json(KEYWORDS_FILE, keywords) + keyword = args[0] + normalized_keyword = link_filter.normalize_link(keyword) if link_filter.link_pattern.match(keyword) else keyword.lower() + if link_filter.remove_keyword(normalized_keyword): await event.reply(f"关键词 '{keyword}' 已删除。") else: await event.reply(f"关键词 '{keyword}' 不存在。") @@ -77,23 +67,22 @@ async def handle_keyword_command(event, command, args): await event.reply("无效的命令或参数。") async def handle_whitelist_command(event, command, args): - whitelist = load_json(WHITELIST_FILE) - if command == '/listwhite': + whitelist = link_filter.whitelist await event.reply("白名单域名列表:\n" + "\n".join(whitelist) if whitelist else "白名单为空。") elif command == '/addwhite' and args: domain = args[0].lower() - if domain not in whitelist: - whitelist.append(domain) - save_json(WHITELIST_FILE, whitelist) + if domain not in link_filter.whitelist: + link_filter.whitelist.append(domain) + link_filter.save_whitelist() await event.reply(f"域名 '{domain}' 已添加到白名单。") else: await event.reply(f"域名 '{domain}' 已在白名单中。") elif command == '/delwhite' and args: domain = args[0].lower() - if domain in whitelist: - whitelist.remove(domain) - save_json(WHITELIST_FILE, whitelist) + if domain in link_filter.whitelist: + link_filter.whitelist.remove(domain) + link_filter.save_whitelist() await event.reply(f"域名 '{domain}' 已从白名单中删除。") else: await event.reply(f"域名 '{domain}' 不在白名单中。") @@ -101,12 +90,9 @@ async def handle_whitelist_command(event, command, args): await event.reply("无效的命令或参数。") def get_keywords(): - return load_json(KEYWORDS_FILE) + return link_filter.keywords def get_whitelist(): - return load_json(WHITELIST_FILE) - - + return link_filter.whitelist __all__ = ['handle_command', 'get_keywords', 'get_whitelist', 'register_commands'] - diff --git a/src/guard.py b/src/guard.py index e02282d..f56c97f 100644 --- a/src/guard.py +++ b/src/guard.py @@ -60,21 +60,23 @@ async def delete_message_after_delay(client, chat, message, delay): # 处理消息函数 async def process_message(event, client): if not event.is_private: - # 检查消息是否包含已知的关键词(包括之前添加的非白名单链接) - if any(keyword in event.message.text for keyword in link_filter.keywords): + # 检查消息是否应该被过滤 + should_filter, new_links = link_filter.should_filter(event.message.text) + + if should_filter: if event.sender_id != ADMIN_ID: await event.delete() - notification = await event.respond("已撤回该消息。注:重复发送的推广链接会被自动撤回。") + notification = await event.respond("已撤回该消息。注:包含关键词或重复发送的非白名单链接会被自动撤回。") asyncio.create_task(delete_message_after_delay(client, event.chat_id, notification, 3 * 60)) return - # 检查是否有新的非白名单链接 - new_links = link_filter.should_filter(event.message.text) if new_links: # 这是第一次发送这些非白名单链接,我们允许消息通过,不发送任何警告 + # 如果需要,可以在这里添加日志记录或其他操作 pass + async def command_handler(event): if event.is_private and event.sender_id == ADMIN_ID: await handle_command(event, event.client) diff --git a/src/link_filter.py b/src/link_filter.py index 7eaf91e..b0cfe1e 100644 --- a/src/link_filter.py +++ b/src/link_filter.py @@ -1,6 +1,7 @@ import re import json import tldextract +import urllib.parse class LinkFilter: def __init__(self, keywords_file, whitelist_file): @@ -38,42 +39,70 @@ class LinkFilter: def is_whitelisted(self, link): extracted = tldextract.extract(link) - full_domain = '.'.join(part for part in [extracted.subdomain, extracted.domain, extracted.suffix] if part) - main_domain = f"{extracted.domain}.{extracted.suffix}" - - # 检查完整域名(包括子域名) - if full_domain in self.whitelist: - return True - - # 检查主域名 - if main_domain in self.whitelist: - return True - - # 检查是否有通配符匹配 - wildcard_domain = f"*.{main_domain}" - if wildcard_domain in self.whitelist: - return True - - return False + domain = f"{extracted.domain}.{extracted.suffix}" + return domain in self.whitelist - def add_keyword(self, link): - if link not in self.keywords: - self.keywords.append(link) + def normalize_link(self, link): + # 解析链接 + parsed = urllib.parse.urlparse(link) + + # 如果没有 scheme,添加 'https://' + if not parsed.scheme: + link = 'https://' + link + parsed = urllib.parse.urlparse(link) + + # 重新组合链接,去除查询参数 + normalized = urllib.parse.urlunparse(( + parsed.scheme, + parsed.netloc, + parsed.path, + '', + '', + '' + )) + + return normalized.rstrip('/') # 移除尾部的斜杠 + def add_keyword(self, keyword): + if self.link_pattern.match(keyword): + keyword = self.normalize_link(keyword) + if keyword not in self.keywords: + self.keywords.append(keyword) self.save_keywords() + def remove_keyword(self, keyword): + if self.link_pattern.match(keyword): + keyword = self.normalize_link(keyword) + if keyword in self.keywords: + self.keywords.remove(keyword) + self.save_keywords() + return True + return False + def should_filter(self, text): + # 检查是否包含关键词 + if any(keyword.lower() in text.lower() for keyword in self.keywords if not self.link_pattern.match(keyword)): + return True, [] + links = self.link_pattern.findall(text) new_non_whitelisted_links = [] for link in links: - if not self.is_whitelisted(link): - if link not in self.keywords: - new_non_whitelisted_links.append(link) - self.add_keyword(link) - return new_non_whitelisted_links + normalized_link = self.normalize_link(link) + if not self.is_whitelisted(normalized_link): + if normalized_link not in self.keywords: + new_non_whitelisted_links.append(normalized_link) + self.add_keyword(normalized_link) + else: + return True, [] # 如果找到已存在的非白名单链接,应该过滤 + + return False, new_non_whitelisted_links def reload_keywords(self): self.keywords = self.load_json(self.keywords_file) def reload_whitelist(self): self.whitelist = self.load_json(self.whitelist_file) + + def save_whitelist(self): + with open(self.whitelist_file, 'w') as f: + json.dump(self.whitelist, f)