diff --git a/src/link_filter.py b/src/link_filter.py
index 88d9283..c33c4a4 100644
--- a/src/link_filter.py
+++ b/src/link_filter.py
@@ -7,10 +7,10 @@ class LinkFilter:
     def __init__(self, keywords_file, whitelist_file):
         self.keywords_file = keywords_file
         self.whitelist_file = whitelist_file
-        self.keywords = self.load_json(keywords_file)
-        self.whitelist = self.load_json(whitelist_file)
+        self.keywords = []
+        self.whitelist = []
+        self.load_data_from_file()
 
-        # Regex that matches the various link formats
         self.link_pattern = re.compile(r'''
             \b
             (?:
@@ -21,7 +21,7 @@ class LinkFilter:
             |                       # or
             (?:t\.me|telegram\.me)  # Telegram links
             )
-            (?:/[^\s]*)?            # optional path
+            (?:/[^\s]*)?            # optional path and query string
             )
             \b
         ''', re.VERBOSE | re.IGNORECASE)
@@ -33,36 +33,30 @@ class LinkFilter:
         except FileNotFoundError:
             return []
 
+    def save_json(self, file_path, data):
+        with open(file_path, 'w') as f:
+            json.dump(data, f)
+
     def save_keywords(self):
-        with open(self.keywords_file, 'w') as f:
-            json.dump(self.keywords, f)
+        self.save_json(self.keywords_file, self.keywords)
+
+    def save_whitelist(self):
+        self.save_json(self.whitelist_file, self.whitelist)
+
+    def load_data_from_file(self):
+        self.keywords = self.load_json(self.keywords_file)
+        self.whitelist = self.load_json(self.whitelist_file)
+
+    def normalize_link(self, link):
+        link = re.sub(r'^https?://', '', link)
+        parsed = urllib.parse.urlparse(f"http://{link}")
+        return urllib.parse.urlunparse(('', parsed.netloc, parsed.path, parsed.params, parsed.query, '')).rstrip('/')
 
     def is_whitelisted(self, link):
         extracted = tldextract.extract(link)
         domain = f"{extracted.domain}.{extracted.suffix}"
         return domain in self.whitelist
-
-    def normalize_link(self, link):
-        # Parse the link
-        parsed = urllib.parse.urlparse(link)
-
-        # If there is no scheme, prepend 'https://'
-        if not parsed.scheme:
-            link = 'https://' + link
-            parsed = urllib.parse.urlparse(link)
-
-        # Reassemble the link, dropping query parameters
-        normalized = urllib.parse.urlunparse((
-            parsed.scheme,
-            parsed.netloc,
-            parsed.path,
-            '',
-            '',
-            ''
-        ))
-
-        return normalized.rstrip('/')  # strip the trailing slash
 
     def add_keyword(self, keyword):
         if self.link_pattern.match(keyword):
             keyword = self.normalize_link(keyword)
@@ -71,16 +65,16 @@ class LinkFilter:
             self.save_keywords()
 
     def remove_keyword(self, keyword):
+        if self.link_pattern.match(keyword):
+            keyword = self.normalize_link(keyword)
         if keyword in self.keywords:
             self.keywords.remove(keyword)
             self.save_keywords()
             return True
         return False
 
-
     def should_filter(self, text):
-        # Check whether the text contains a keyword
-        if any(keyword.lower() in text.lower() for keyword in self.keywords if not self.link_pattern.match(keyword)):
+        if any(keyword.lower() in text.lower() for keyword in self.keywords):
             return True, []
 
         links = self.link_pattern.findall(text)
@@ -92,16 +86,6 @@ class LinkFilter:
                 new_non_whitelisted_links.append(normalized_link)
                 self.add_keyword(normalized_link)
             else:
-                return True, []  # an existing non-whitelisted link was found, so filter
+                return True, []
 
         return False, new_non_whitelisted_links
-
-    def reload_keywords(self):
-        self.keywords = self.load_json(self.keywords_file)
-
-    def reload_whitelist(self):
-        self.whitelist = self.load_json(self.whitelist_file)
-
-    def save_whitelist(self):
-        with open(self.whitelist_file, 'w') as f:
-            json.dump(self.whitelist, f)
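
A quick sanity-check sketch of the refactored behavior, not part of the patch itself. It assumes the module is importable as src.link_filter and that LinkFilter needs only the two JSON paths shown in __init__; the temp-file setup and file names are illustrative. Note that urllib.parse.urlunparse emits a leading "//" when the scheme is empty; since keywords and incoming links both pass through the same normalize_link, the prefix is consistent on both sides of every comparison.

# Hypothetical smoke test; paths and file names are illustrative only.
import json
import tempfile
from pathlib import Path

from src.link_filter import LinkFilter

tmp = Path(tempfile.mkdtemp())
keywords_file = tmp / "keywords.json"
whitelist_file = tmp / "whitelist.json"
keywords_file.write_text(json.dumps([]))
whitelist_file.write_text(json.dumps(["github.com"]))

lf = LinkFilter(str(keywords_file), str(whitelist_file))

# normalize_link now keeps the query string, so links differing only in
# their parameters normalize to distinct keywords; trailing slashes still drop.
assert lf.normalize_link("https://example.com/page?id=1") == "//example.com/page?id=1"
assert lf.normalize_link("example.com/page/") == "//example.com/page"

# Saves go through the shared save_json helper and load_data_from_file
# restores both lists, so edits survive a reload.
lf.whitelist.append("python.org")
lf.save_whitelist()
lf.load_data_from_file()
assert lf.is_whitelisted("https://docs.python.org/3/")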