引入sqlite

This commit is contained in:
wood 2024-09-09 19:59:06 +08:00
parent c4c558ce1d
commit 8a2ade4ea4
5 changed files with 142 additions and 76 deletions

View File

@ -3,4 +3,3 @@ ccxt
pyTelegramBotAPI
schedule
pytz
tldextract

82
src/database.py Normal file
View File

@ -0,0 +1,82 @@
import sqlite3
import logging
logger = logging.getLogger(__name__)
class Database:
    """SQLite-backed storage for filter keywords and whitelisted domains.

    A fresh connection is opened and closed for every operation, so
    instances are safe to use without explicit lifecycle management.
    Errors are logged and reported as ``None``/empty results rather than
    raised, matching the original best-effort behaviour.
    """

    class _Result:
        """Detached query result: rows plus rowcount, usable after the
        originating connection has been closed (a live cursor is not)."""

        def __init__(self, rows, rowcount):
            self.rows = rows
            self.rowcount = rowcount

        def __iter__(self):
            return iter(self.rows)

        def fetchall(self):
            # Cursor-compatible accessor for callers that expect one.
            return list(self.rows)

    def __init__(self, db_file):
        # Path to the SQLite database file.
        self.db_file = db_file
        # Kept as an attribute for backward compatibility; each operation
        # rebinds it to a short-lived connection.
        self.conn = None
        self.create_tables()

    def create_tables(self):
        """Create the keywords and whitelist tables if they do not exist."""
        try:
            self.conn = sqlite3.connect(self.db_file)
            cursor = self.conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS keywords
                (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS whitelist
                (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
                """
            )
            self.conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database error: {e}")
        finally:
            if self.conn:
                self.conn.close()

    def execute_query(self, query, params=None):
        """Execute a single statement and return a detached result.

        Bug fix: the original returned the live cursor after its
        connection was closed in ``finally``, so iterating the result
        (get_all_keywords / get_all_whitelist) raised
        sqlite3.ProgrammingError. Rows and rowcount are now captured
        while the connection is still open.

        Returns the result object, or None on a database error.
        """
        try:
            self.conn = sqlite3.connect(self.db_file)
            cursor = self.conn.cursor()
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            # Materialize before close; fetchall() is [] for non-SELECT.
            rows = cursor.fetchall()
            rowcount = cursor.rowcount
            self.conn.commit()
            return self._Result(rows, rowcount)
        except sqlite3.Error as e:
            logger.error(f"Query execution error: {e}")
            return None
        finally:
            if self.conn:
                self.conn.close()

    def add_keyword(self, keyword):
        """Insert a keyword; duplicates are silently ignored (UNIQUE)."""
        query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
        self.execute_query(query, (keyword,))

    def remove_keyword(self, keyword):
        """Delete an exactly-matching keyword."""
        query = "DELETE FROM keywords WHERE keyword = ?"
        self.execute_query(query, (keyword,))

    def get_all_keywords(self):
        """Return every stored keyword as a list of strings."""
        query = "SELECT keyword FROM keywords"
        result = self.execute_query(query)
        return [row[0] for row in result] if result else []

    def remove_keywords_containing(self, substring):
        """Delete all keywords containing *substring* (SQL LIKE match,
        case-insensitive for ASCII) and return how many were removed."""
        query = "DELETE FROM keywords WHERE keyword LIKE ?"
        result = self.execute_query(query, (f"%{substring}%",))
        return result.rowcount if result else 0

    def add_whitelist(self, domain):
        """Insert a whitelisted domain; duplicates are silently ignored."""
        query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
        self.execute_query(query, (domain,))

    def remove_whitelist(self, domain):
        """Delete an exactly-matching whitelisted domain."""
        query = "DELETE FROM whitelist WHERE domain = ?"
        self.execute_query(query, (domain,))

    def get_all_whitelist(self):
        """Return every whitelisted domain as a list of strings."""
        query = "SELECT domain FROM whitelist"
        result = self.execute_query(query)
        return [row[0] for row in result] if result else []

View File

@ -8,13 +8,10 @@ import time
from link_filter import LinkFilter
from bot_commands import register_commands
# 环境变量
BOT_TOKEN = os.environ.get('BOT_TOKEN')
ADMIN_ID = int(os.environ.get('ADMIN_ID'))
KEYWORDS_FILE = '/app/data/keywords.json'
WHITELIST_FILE = '/app/data/whitelist.json'
DB_FILE = '/app/data/q58.db' # 新的数据库文件路径
# 设置日志
DEBUG_MODE = os.environ.get('DEBUG_MODE', 'False').lower() == 'true'
@ -30,7 +27,7 @@ link_filter_logger.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)
logging.getLogger('telethon').setLevel(logging.WARNING)
# 创建 LinkFilter 实例
link_filter = LinkFilter(KEYWORDS_FILE, WHITELIST_FILE)
link_filter = LinkFilter(DB_FILE)
class RateLimiter:
def __init__(self, max_calls, period):
@ -73,7 +70,6 @@ async def process_message(event, client):
if new_links:
logger.info(f"New non-whitelisted links found: {new_links}")
async def message_handler(event, link_filter, rate_limiter):
if not event.is_private or event.sender_id != ADMIN_ID:
async with rate_limiter:
@ -94,6 +90,7 @@ async def command_handler(event, link_filter):
if event.raw_text.startswith(('/add', '/delete', '/deletecontaining','/list', '/addwhite', '/delwhite', '/listwhite')):
link_filter.load_data_from_file()
async def start_bot():
async with TelegramClient('bot', api_id=6, api_hash='eb06d4abfb49dc3eeb1aeb98ae0f581e') as client:
await client.start(bot_token=BOT_TOKEN)

View File

@ -1,20 +1,16 @@
import re
import json
import tldextract
import urllib.parse
import logging
from database import Database
from functions import send_long_message
logger = logging.getLogger("TeleGuard.LinkFilter")
class LinkFilter:
def __init__(self, keywords_file, whitelist_file):
self.keywords_file = keywords_file
self.whitelist_file = whitelist_file
self.keywords = []
self.whitelist = []
self.load_data_from_file()
def __init__(self, db_file):
self.db = Database(db_file)
self.keywords = self.db.get_all_keywords()
self.whitelist = self.db.get_all_whitelist()
self.link_pattern = re.compile(
r"""
@ -34,49 +30,34 @@ class LinkFilter:
re.VERBOSE | re.IGNORECASE,
)
def load_json(self, file_path):
try:
with open(file_path, "r") as f:
return json.load(f)
except FileNotFoundError:
return []
def save_json(self, file_path, data):
with open(file_path, "w") as f:
json.dump(data, f)
def save_keywords(self):
self.save_json(self.keywords_file, self.keywords)
def save_whitelist(self):
self.save_json(self.whitelist_file, self.whitelist)
def load_data_from_file(self):
self.keywords = self.load_json(self.keywords_file)
self.whitelist = self.load_json(self.whitelist_file)
self.keywords = self.db.get_all_keywords()
self.whitelist = self.db.get_all_whitelist()
logger.info(
f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries"
)
def normalize_link(self, link):
    """Canonicalize *link* for keyword comparison.

    The scheme (http/https) and any leading slashes are stripped, the
    remainder is re-parsed, and the URL is rebuilt without scheme or
    fragment. The returned value keeps urlunparse's leading ``//`` —
    callers strip it — and carries no trailing slash.
    """
    # Drop the protocol prefix, then any leading slashes.
    bare = re.sub(r"^https?://", "", link).lstrip("/")
    parsed = urllib.parse.urlparse(f"http://{bare}")
    rebuilt = urllib.parse.urlunparse(
        ("", parsed.netloc, parsed.path, parsed.params, parsed.query, "")
    )
    result = rebuilt.rstrip("/")
    logger.debug(f"Normalized link: {bare} -> {result}")
    return result
def extract_domain(self, url):
    """Return the lower-cased domain of *url*, reduced to its last two
    dot-separated labels.

    Falls back to the path component when urlparse yields no netloc
    (inputs without a scheme). An explicit port is dropped.
    NOTE(review): multi-part public suffixes such as ``co.uk`` collapse
    to the suffix itself — confirm whitelist entries account for this.
    """
    parsed = urllib.parse.urlparse(url)
    host = parsed.netloc if parsed.netloc else parsed.path
    host, _, _ = host.partition(":")  # strip an explicit port, if any
    labels = host.split(".")
    if len(labels) > 2:
        host = ".".join(labels[-2:])
    return host.lower()
def is_whitelisted(self, link):
extracted = tldextract.extract(link)
domain = f"{extracted.domain}.{extracted.suffix}"
domain = self.extract_domain(link)
result = domain in self.whitelist
logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}")
return result
@ -84,44 +65,26 @@ class LinkFilter:
def add_keyword(self, keyword):
if self.link_pattern.match(keyword):
keyword = self.normalize_link(keyword)
# 确保在这里去掉开头的双斜杠
keyword = keyword.lstrip("/")
if keyword not in self.keywords:
self.keywords.append(keyword)
self.save_keywords()
self.db.add_keyword(keyword)
logger.info(f"New keyword added: {keyword}")
self.load_data_from_file() # 重新加载文件
self.load_data_from_file()
else:
logger.debug(f"Keyword already exists: {keyword}")
def remove_keyword(self, keyword):
if keyword in self.keywords:
self.keywords.remove(keyword)
self.save_keywords()
self.load_data_from_file() # 重新加载以确保数据同步
self.db.remove_keyword(keyword)
self.load_data_from_file()
return True
return False
def remove_keywords_containing(self, substring):
# 记录原始关键词列表的长度
original_count = len(self.keywords)
# 创建一个列表,包含所有需要移除的关键词
removed_keywords = [
kw for kw in self.keywords if substring.lower() in kw.lower()
]
# 修改关键词列表,仅保留不包含指定子字符串的关键词
self.keywords = [
kw for kw in self.keywords if substring.lower() not in kw.lower()
]
# 如果有关键词被移除,则保存关键词列表并重新加载数据
if removed_keywords:
self.save_keywords()
self.load_data_from_file()
# 返回被移除的关键词列表
removed_keywords = [kw for kw in self.keywords if substring.lower() in kw.lower()]
for keyword in removed_keywords:
self.db.remove_keyword(keyword)
self.load_data_from_file()
return removed_keywords
def should_filter(self, text):
@ -135,7 +98,7 @@ class LinkFilter:
new_non_whitelisted_links = []
for link in links:
normalized_link = self.normalize_link(link)
normalized_link = normalized_link.lstrip("/") # 去除开头的双斜杠
normalized_link = normalized_link.lstrip("/")
if not self.is_whitelisted(normalized_link):
logger.debug(f"Link not whitelisted: {normalized_link}")
if normalized_link not in self.keywords:
@ -169,9 +132,7 @@ class LinkFilter:
if self.remove_keyword(keyword):
await event.reply(f"关键词 '{keyword}' 已删除。")
else:
similar_keywords = [
k for k in self.keywords if keyword.lower() in k.lower()
]
similar_keywords = [k for k in self.keywords if keyword.lower() in k.lower()]
if similar_keywords:
await send_long_message(
event,
@ -204,8 +165,7 @@ class LinkFilter:
elif command == "/addwhite" and args:
domain = args[0].lower()
if domain not in self.whitelist:
self.whitelist.append(domain)
self.save_whitelist()
self.db.add_whitelist(domain)
self.load_data_from_file()
await event.reply(f"域名 '{domain}' 已添加到白名单。")
else:
@ -213,8 +173,7 @@ class LinkFilter:
elif command == "/delwhite" and args:
domain = args[0].lower()
if domain in self.whitelist:
self.whitelist.remove(domain)
self.save_whitelist()
self.db.remove_whitelist(domain)
self.load_data_from_file()
await event.reply(f"域名 '{domain}' 已从白名单中删除。")
else:

29
src/migrate.py Normal file
View File

@ -0,0 +1,29 @@
import json
import os
from database import Database
def migrate_data(json_file, db_file):
    """One-off migration of keywords and whitelist entries from a legacy
    JSON file into the SQLite database.

    Args:
        json_file: path to the legacy JSON file; may contain optional
            "keywords" and "whitelist" lists.
        db_file: destination SQLite database path.
    """
    # Ensure the target directory exists. Guard against a bare filename:
    # os.path.dirname() returns '' then, and os.makedirs('') raises.
    db_dir = os.path.dirname(db_file)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    # Create the database connection (tables are created on construction).
    db = Database(db_file)
    # Read the legacy JSON explicitly as UTF-8 so non-ASCII keywords
    # survive on platforms whose default encoding differs.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    keywords = data.get('keywords', [])
    whitelist = data.get('whitelist', [])
    # Migrate keywords (duplicates are ignored by the UNIQUE constraint).
    for keyword in keywords:
        db.add_keyword(keyword)
    # Migrate whitelisted domains.
    for domain in whitelist:
        db.add_whitelist(domain)
    print(f"迁移完成。关键词:{len(keywords)}个,白名单:{len(whitelist)}")
if __name__ == "__main__":
json_file = 'keywords.json' # 旧的 JSON 文件路径
db_file = os.path.join('data', 'q58.db') # 新的数据库文件路径
migrate_data(json_file, db_file)