引入sqlite

This commit is contained in:
wood 2024-09-09 19:59:06 +08:00
parent c4c558ce1d
commit 8a2ade4ea4
5 changed files with 142 additions and 76 deletions

View File

@ -3,4 +3,3 @@ ccxt
pyTelegramBotAPI pyTelegramBotAPI
schedule schedule
pytz pytz
tldextract

82
src/database.py Normal file
View File

@ -0,0 +1,82 @@
import sqlite3
import logging

logger = logging.getLogger(__name__)


class _QueryResult:
    """Materialized result of a query: iterable rows plus the DML rowcount.

    Stands in for the live sqlite3 cursor the old code returned after
    closing its connection — iterating that cursor raised
    sqlite3.ProgrammingError ("Cannot operate on a closed database").
    """

    def __init__(self, rows, rowcount):
        self.rows = rows          # rows already fetched (list of tuples)
        self.rowcount = rowcount  # affected-row count for DELETE/UPDATE, -1 for SELECT

    def __iter__(self):
        return iter(self.rows)

    def __bool__(self):
        # A successful query is truthy even with zero rows, matching the
        # old "if cursor" checks in the helper methods below.
        return True


class Database:
    """SQLite-backed store for filter keywords and whitelisted domains.

    A fresh connection is opened and closed for every operation, so no
    long-lived connection state is held between calls.
    """

    def __init__(self, db_file):
        self.db_file = db_file
        self.conn = None  # retained for backward compatibility; connections are per-call
        self.create_tables()

    def create_tables(self):
        """Create the keywords and whitelist tables if they do not exist."""
        conn = None
        try:
            conn = sqlite3.connect(self.db_file)
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS keywords
                (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS whitelist
                (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
                """
            )
            conn.commit()
        except sqlite3.Error as e:
            logger.error("Database error: %s", e)
        finally:
            if conn:
                conn.close()

    def execute_query(self, query, params=None):
        """Execute one parameterized statement and return its result.

        Returns a _QueryResult (iterable rows, .rowcount) on success, or
        None on sqlite3 errors.  Rows are fetched BEFORE the connection is
        closed; the previous implementation returned a live cursor after
        closing, which made every SELECT helper blow up on iteration.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_file)
            cursor = conn.cursor()
            cursor.execute(query, params or ())
            rows = cursor.fetchall()
            rowcount = cursor.rowcount
            conn.commit()
            return _QueryResult(rows, rowcount)
        except sqlite3.Error as e:
            logger.error("Query execution error: %s", e)
            return None
        finally:
            if conn:
                conn.close()

    def add_keyword(self, keyword):
        """Insert a keyword; duplicates are silently ignored (UNIQUE + OR IGNORE)."""
        self.execute_query(
            "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)", (keyword,)
        )

    def remove_keyword(self, keyword):
        """Delete an exact keyword match."""
        self.execute_query("DELETE FROM keywords WHERE keyword = ?", (keyword,))

    def get_all_keywords(self):
        """Return every stored keyword as a list of strings."""
        result = self.execute_query("SELECT keyword FROM keywords")
        return [row[0] for row in result] if result else []

    def remove_keywords_containing(self, substring):
        """Delete keywords containing *substring*; return how many were removed.

        NOTE(review): *substring* is interpolated into a LIKE pattern, so
        literal '%' or '_' in it act as wildcards — confirm callers never
        pass those characters if exact-substring semantics are required.
        """
        result = self.execute_query(
            "DELETE FROM keywords WHERE keyword LIKE ?", (f"%{substring}%",)
        )
        return result.rowcount if result else 0

    def add_whitelist(self, domain):
        """Insert a whitelisted domain; duplicates are silently ignored."""
        self.execute_query(
            "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)", (domain,)
        )

    def remove_whitelist(self, domain):
        """Delete a domain from the whitelist."""
        self.execute_query("DELETE FROM whitelist WHERE domain = ?", (domain,))

    def get_all_whitelist(self):
        """Return every whitelisted domain as a list of strings."""
        result = self.execute_query("SELECT domain FROM whitelist")
        return [row[0] for row in result] if result else []

View File

@ -8,13 +8,10 @@ import time
from link_filter import LinkFilter from link_filter import LinkFilter
from bot_commands import register_commands from bot_commands import register_commands
# 环境变量 # 环境变量
BOT_TOKEN = os.environ.get('BOT_TOKEN') BOT_TOKEN = os.environ.get('BOT_TOKEN')
ADMIN_ID = int(os.environ.get('ADMIN_ID')) ADMIN_ID = int(os.environ.get('ADMIN_ID'))
KEYWORDS_FILE = '/app/data/keywords.json' DB_FILE = '/app/data/q58.db' # 新的数据库文件路径
WHITELIST_FILE = '/app/data/whitelist.json'
# 设置日志 # 设置日志
DEBUG_MODE = os.environ.get('DEBUG_MODE', 'False').lower() == 'true' DEBUG_MODE = os.environ.get('DEBUG_MODE', 'False').lower() == 'true'
@ -30,7 +27,7 @@ link_filter_logger.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)
logging.getLogger('telethon').setLevel(logging.WARNING) logging.getLogger('telethon').setLevel(logging.WARNING)
# 创建 LinkFilter 实例 # 创建 LinkFilter 实例
link_filter = LinkFilter(KEYWORDS_FILE, WHITELIST_FILE) link_filter = LinkFilter(DB_FILE)
class RateLimiter: class RateLimiter:
def __init__(self, max_calls, period): def __init__(self, max_calls, period):
@ -73,7 +70,6 @@ async def process_message(event, client):
if new_links: if new_links:
logger.info(f"New non-whitelisted links found: {new_links}") logger.info(f"New non-whitelisted links found: {new_links}")
async def message_handler(event, link_filter, rate_limiter): async def message_handler(event, link_filter, rate_limiter):
if not event.is_private or event.sender_id != ADMIN_ID: if not event.is_private or event.sender_id != ADMIN_ID:
async with rate_limiter: async with rate_limiter:
@ -94,6 +90,7 @@ async def command_handler(event, link_filter):
if event.raw_text.startswith(('/add', '/delete', '/deletecontaining','/list', '/addwhite', '/delwhite', '/listwhite')): if event.raw_text.startswith(('/add', '/delete', '/deletecontaining','/list', '/addwhite', '/delwhite', '/listwhite')):
link_filter.load_data_from_file() link_filter.load_data_from_file()
async def start_bot(): async def start_bot():
async with TelegramClient('bot', api_id=6, api_hash='eb06d4abfb49dc3eeb1aeb98ae0f581e') as client: async with TelegramClient('bot', api_id=6, api_hash='eb06d4abfb49dc3eeb1aeb98ae0f581e') as client:
await client.start(bot_token=BOT_TOKEN) await client.start(bot_token=BOT_TOKEN)

View File

@ -1,20 +1,16 @@
import re import re
import json
import tldextract
import urllib.parse import urllib.parse
import logging import logging
from database import Database
from functions import send_long_message from functions import send_long_message
logger = logging.getLogger("TeleGuard.LinkFilter") logger = logging.getLogger("TeleGuard.LinkFilter")
class LinkFilter: class LinkFilter:
def __init__(self, keywords_file, whitelist_file): def __init__(self, db_file):
self.keywords_file = keywords_file self.db = Database(db_file)
self.whitelist_file = whitelist_file self.keywords = self.db.get_all_keywords()
self.keywords = [] self.whitelist = self.db.get_all_whitelist()
self.whitelist = []
self.load_data_from_file()
self.link_pattern = re.compile( self.link_pattern = re.compile(
r""" r"""
@ -34,49 +30,34 @@ class LinkFilter:
re.VERBOSE | re.IGNORECASE, re.VERBOSE | re.IGNORECASE,
) )
def load_json(self, file_path):
try:
with open(file_path, "r") as f:
return json.load(f)
except FileNotFoundError:
return []
def save_json(self, file_path, data):
with open(file_path, "w") as f:
json.dump(data, f)
def save_keywords(self):
self.save_json(self.keywords_file, self.keywords)
def save_whitelist(self):
self.save_json(self.whitelist_file, self.whitelist)
def load_data_from_file(self): def load_data_from_file(self):
self.keywords = self.load_json(self.keywords_file) self.keywords = self.db.get_all_keywords()
self.whitelist = self.load_json(self.whitelist_file) self.whitelist = self.db.get_all_whitelist()
logger.info( logger.info(
f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries" f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries"
) )
def normalize_link(self, link): def normalize_link(self, link):
# 移除协议部分(如 http:// 或 https://
link = re.sub(r"^https?://", "", link) link = re.sub(r"^https?://", "", link)
# 移除开头的双斜杠
link = link.lstrip("/") link = link.lstrip("/")
parsed = urllib.parse.urlparse(f"http://{link}") parsed = urllib.parse.urlparse(f"http://{link}")
normalized = urllib.parse.urlunparse( normalized = urllib.parse.urlunparse(
("", parsed.netloc, parsed.path, parsed.params, parsed.query, "") ("", parsed.netloc, parsed.path, parsed.params, parsed.query, "")
) )
result = normalized.rstrip("/") result = normalized.rstrip("/")
logger.debug(f"Normalized link: {link} -> {result}") logger.debug(f"Normalized link: {link} -> {result}")
return result return result
def extract_domain(self, url):
parsed = urllib.parse.urlparse(url)
domain = parsed.netloc or parsed.path
domain = domain.split(':')[0] # Remove port if present
parts = domain.split('.')
if len(parts) > 2:
domain = '.'.join(parts[-2:])
return domain.lower()
def is_whitelisted(self, link): def is_whitelisted(self, link):
extracted = tldextract.extract(link) domain = self.extract_domain(link)
domain = f"{extracted.domain}.{extracted.suffix}"
result = domain in self.whitelist result = domain in self.whitelist
logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}") logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}")
return result return result
@ -84,44 +65,26 @@ class LinkFilter:
def add_keyword(self, keyword): def add_keyword(self, keyword):
if self.link_pattern.match(keyword): if self.link_pattern.match(keyword):
keyword = self.normalize_link(keyword) keyword = self.normalize_link(keyword)
# 确保在这里去掉开头的双斜杠
keyword = keyword.lstrip("/") keyword = keyword.lstrip("/")
if keyword not in self.keywords: if keyword not in self.keywords:
self.keywords.append(keyword) self.db.add_keyword(keyword)
self.save_keywords()
logger.info(f"New keyword added: {keyword}") logger.info(f"New keyword added: {keyword}")
self.load_data_from_file() # 重新加载文件 self.load_data_from_file()
else: else:
logger.debug(f"Keyword already exists: {keyword}") logger.debug(f"Keyword already exists: {keyword}")
def remove_keyword(self, keyword): def remove_keyword(self, keyword):
if keyword in self.keywords: if keyword in self.keywords:
self.keywords.remove(keyword) self.db.remove_keyword(keyword)
self.save_keywords() self.load_data_from_file()
self.load_data_from_file() # 重新加载以确保数据同步
return True return True
return False return False
def remove_keywords_containing(self, substring): def remove_keywords_containing(self, substring):
# 记录原始关键词列表的长度 removed_keywords = [kw for kw in self.keywords if substring.lower() in kw.lower()]
original_count = len(self.keywords) for keyword in removed_keywords:
self.db.remove_keyword(keyword)
# 创建一个列表,包含所有需要移除的关键词
removed_keywords = [
kw for kw in self.keywords if substring.lower() in kw.lower()
]
# 修改关键词列表,仅保留不包含指定子字符串的关键词
self.keywords = [
kw for kw in self.keywords if substring.lower() not in kw.lower()
]
# 如果有关键词被移除,则保存关键词列表并重新加载数据
if removed_keywords:
self.save_keywords()
self.load_data_from_file() self.load_data_from_file()
# 返回被移除的关键词列表
return removed_keywords return removed_keywords
def should_filter(self, text): def should_filter(self, text):
@ -135,7 +98,7 @@ class LinkFilter:
new_non_whitelisted_links = [] new_non_whitelisted_links = []
for link in links: for link in links:
normalized_link = self.normalize_link(link) normalized_link = self.normalize_link(link)
normalized_link = normalized_link.lstrip("/") # 去除开头的双斜杠 normalized_link = normalized_link.lstrip("/")
if not self.is_whitelisted(normalized_link): if not self.is_whitelisted(normalized_link):
logger.debug(f"Link not whitelisted: {normalized_link}") logger.debug(f"Link not whitelisted: {normalized_link}")
if normalized_link not in self.keywords: if normalized_link not in self.keywords:
@ -169,9 +132,7 @@ class LinkFilter:
if self.remove_keyword(keyword): if self.remove_keyword(keyword):
await event.reply(f"关键词 '{keyword}' 已删除。") await event.reply(f"关键词 '{keyword}' 已删除。")
else: else:
similar_keywords = [ similar_keywords = [k for k in self.keywords if keyword.lower() in k.lower()]
k for k in self.keywords if keyword.lower() in k.lower()
]
if similar_keywords: if similar_keywords:
await send_long_message( await send_long_message(
event, event,
@ -204,8 +165,7 @@ class LinkFilter:
elif command == "/addwhite" and args: elif command == "/addwhite" and args:
domain = args[0].lower() domain = args[0].lower()
if domain not in self.whitelist: if domain not in self.whitelist:
self.whitelist.append(domain) self.db.add_whitelist(domain)
self.save_whitelist()
self.load_data_from_file() self.load_data_from_file()
await event.reply(f"域名 '{domain}' 已添加到白名单。") await event.reply(f"域名 '{domain}' 已添加到白名单。")
else: else:
@ -213,8 +173,7 @@ class LinkFilter:
elif command == "/delwhite" and args: elif command == "/delwhite" and args:
domain = args[0].lower() domain = args[0].lower()
if domain in self.whitelist: if domain in self.whitelist:
self.whitelist.remove(domain) self.db.remove_whitelist(domain)
self.save_whitelist()
self.load_data_from_file() self.load_data_from_file()
await event.reply(f"域名 '{domain}' 已从白名单中删除。") await event.reply(f"域名 '{domain}' 已从白名单中删除。")
else: else:

29
src/migrate.py Normal file
View File

@ -0,0 +1,29 @@
import json
import os
from database import Database


def migrate_data(json_file, db_file):
    """One-off migration of keywords/whitelist from legacy JSON into SQLite.

    Reads *json_file* (expected keys: "keywords" and "whitelist", both
    lists of strings) and inserts every entry into the database at
    *db_file*.  Inserts use INSERT OR IGNORE underneath, so re-running
    the migration is safe (idempotent).
    """
    # Create the database's parent directory if needed.  Guard against a
    # bare filename: os.path.dirname('') would make makedirs('') raise.
    parent = os.path.dirname(db_file)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Creating the Database also creates the tables.
    db = Database(db_file)

    # Explicit utf-8 so non-ASCII keywords migrate correctly regardless
    # of the platform's default encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Migrate keywords.
    for keyword in data.get("keywords", []):
        db.add_keyword(keyword)

    # Migrate whitelist domains.
    for domain in data.get("whitelist", []):
        db.add_whitelist(domain)

    print(f"迁移完成。关键词:{len(data.get('keywords', []))}个,白名单:{len(data.get('whitelist', []))}")


if __name__ == "__main__":
    json_file = 'keywords.json'  # legacy JSON file path
    db_file = os.path.join('data', 'q58.db')  # new database file path
    migrate_data(json_file, db_file)