mirror of
https://github.com/woodchen-ink/Q58Bot.git
synced 2025-07-18 13:52:07 +08:00
引入sqlite
This commit is contained in:
parent
c4c558ce1d
commit
8a2ade4ea4
@ -3,4 +3,3 @@ ccxt
|
||||
pyTelegramBotAPI
|
||||
schedule
|
||||
pytz
|
||||
tldextract
|
||||
|
82
src/database.py
Normal file
82
src/database.py
Normal file
@ -0,0 +1,82 @@
|
||||
import sqlite3
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
    """SQLite-backed persistence for filter keywords and whitelisted domains.

    A short-lived connection is opened per operation, so callers never have
    to manage connection lifetime themselves.
    """

    def __init__(self, db_file):
        """Remember the database path and make sure the tables exist.

        :param db_file: filesystem path to the SQLite database file
        """
        self.db_file = db_file
        # Kept for backward compatibility with the original attribute layout;
        # connections are actually opened and closed per call.
        self.conn = None
        self.create_tables()

    def create_tables(self):
        """Create the ``keywords`` and ``whitelist`` tables if missing.

        Errors are logged rather than raised, matching the original
        best-effort behavior.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_file)
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS keywords
                (id INTEGER PRIMARY KEY, keyword TEXT UNIQUE)
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS whitelist
                (id INTEGER PRIMARY KEY, domain TEXT UNIQUE)
                """
            )
            conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database error: {e}")
        finally:
            if conn:
                conn.close()

    def execute_query(self, query, params=None):
        """Run a single SQL statement against a fresh connection.

        BUG FIX: the original returned the cursor *after* closing the
        connection in ``finally``, so iterating it (as ``get_all_keywords``
        does) raised ``sqlite3.ProgrammingError``. Results are now
        materialized before the connection is closed.

        :param query: SQL text, with ``?`` placeholders
        :param params: optional parameter tuple for the placeholders
        :return: list of result rows for a SELECT, the affected row count
            for other statements, or ``None`` on error
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_file)
            cursor = conn.cursor()
            cursor.execute(query, params or ())
            # cursor.description is non-None only for row-returning statements.
            if cursor.description is not None:
                result = cursor.fetchall()
            else:
                result = cursor.rowcount
            conn.commit()
            return result
        except sqlite3.Error as e:
            logger.error(f"Query execution error: {e}")
            return None
        finally:
            if conn:
                conn.close()

    def add_keyword(self, keyword):
        """Insert *keyword*; duplicates are silently ignored (UNIQUE column)."""
        query = "INSERT OR IGNORE INTO keywords (keyword) VALUES (?)"
        self.execute_query(query, (keyword,))

    def remove_keyword(self, keyword):
        """Delete *keyword* from the keywords table if present."""
        query = "DELETE FROM keywords WHERE keyword = ?"
        self.execute_query(query, (keyword,))

    def get_all_keywords(self):
        """Return every stored keyword as a list of strings."""
        rows = self.execute_query("SELECT keyword FROM keywords")
        return [row[0] for row in rows] if rows else []

    def remove_keywords_containing(self, substring):
        """Delete all keywords containing *substring*; return how many were removed."""
        query = "DELETE FROM keywords WHERE keyword LIKE ?"
        removed = self.execute_query(query, (f"%{substring}%",))
        return removed if removed else 0

    def add_whitelist(self, domain):
        """Insert *domain* into the whitelist; duplicates are silently ignored."""
        query = "INSERT OR IGNORE INTO whitelist (domain) VALUES (?)"
        self.execute_query(query, (domain,))

    def remove_whitelist(self, domain):
        """Delete *domain* from the whitelist if present."""
        query = "DELETE FROM whitelist WHERE domain = ?"
        self.execute_query(query, (domain,))

    def get_all_whitelist(self):
        """Return every whitelisted domain as a list of strings."""
        rows = self.execute_query("SELECT domain FROM whitelist")
        return [row[0] for row in rows] if rows else []
|
@ -8,13 +8,10 @@ import time
|
||||
from link_filter import LinkFilter
|
||||
from bot_commands import register_commands
|
||||
|
||||
|
||||
|
||||
# 环境变量
|
||||
BOT_TOKEN = os.environ.get('BOT_TOKEN')
|
||||
ADMIN_ID = int(os.environ.get('ADMIN_ID'))
|
||||
KEYWORDS_FILE = '/app/data/keywords.json'
|
||||
WHITELIST_FILE = '/app/data/whitelist.json'
|
||||
DB_FILE = '/app/data/q58.db' # 新的数据库文件路径
|
||||
|
||||
# 设置日志
|
||||
DEBUG_MODE = os.environ.get('DEBUG_MODE', 'False').lower() == 'true'
|
||||
@ -30,7 +27,7 @@ link_filter_logger.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)
|
||||
logging.getLogger('telethon').setLevel(logging.WARNING)
|
||||
|
||||
# 创建 LinkFilter 实例
|
||||
link_filter = LinkFilter(KEYWORDS_FILE, WHITELIST_FILE)
|
||||
link_filter = LinkFilter(DB_FILE)
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, max_calls, period):
|
||||
@ -73,7 +70,6 @@ async def process_message(event, client):
|
||||
if new_links:
|
||||
logger.info(f"New non-whitelisted links found: {new_links}")
|
||||
|
||||
|
||||
async def message_handler(event, link_filter, rate_limiter):
|
||||
if not event.is_private or event.sender_id != ADMIN_ID:
|
||||
async with rate_limiter:
|
||||
@ -94,6 +90,7 @@ async def command_handler(event, link_filter):
|
||||
|
||||
if event.raw_text.startswith(('/add', '/delete', '/deletecontaining','/list', '/addwhite', '/delwhite', '/listwhite')):
|
||||
link_filter.load_data_from_file()
|
||||
|
||||
async def start_bot():
|
||||
async with TelegramClient('bot', api_id=6, api_hash='eb06d4abfb49dc3eeb1aeb98ae0f581e') as client:
|
||||
await client.start(bot_token=BOT_TOKEN)
|
||||
|
@ -1,20 +1,16 @@
|
||||
import re
|
||||
import json
|
||||
import tldextract
|
||||
import urllib.parse
|
||||
import logging
|
||||
from database import Database
|
||||
from functions import send_long_message
|
||||
|
||||
logger = logging.getLogger("TeleGuard.LinkFilter")
|
||||
|
||||
|
||||
class LinkFilter:
|
||||
def __init__(self, keywords_file, whitelist_file):
|
||||
self.keywords_file = keywords_file
|
||||
self.whitelist_file = whitelist_file
|
||||
self.keywords = []
|
||||
self.whitelist = []
|
||||
self.load_data_from_file()
|
||||
def __init__(self, db_file):
|
||||
self.db = Database(db_file)
|
||||
self.keywords = self.db.get_all_keywords()
|
||||
self.whitelist = self.db.get_all_whitelist()
|
||||
|
||||
self.link_pattern = re.compile(
|
||||
r"""
|
||||
@ -34,49 +30,34 @@ class LinkFilter:
|
||||
re.VERBOSE | re.IGNORECASE,
|
||||
)
|
||||
|
||||
def load_json(self, file_path):
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
|
||||
def save_json(self, file_path, data):
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def save_keywords(self):
|
||||
self.save_json(self.keywords_file, self.keywords)
|
||||
|
||||
def save_whitelist(self):
|
||||
self.save_json(self.whitelist_file, self.whitelist)
|
||||
|
||||
def load_data_from_file(self):
|
||||
self.keywords = self.load_json(self.keywords_file)
|
||||
self.whitelist = self.load_json(self.whitelist_file)
|
||||
self.keywords = self.db.get_all_keywords()
|
||||
self.whitelist = self.db.get_all_whitelist()
|
||||
logger.info(
|
||||
f"Reloaded {len(self.keywords)} keywords and {len(self.whitelist)} whitelist entries"
|
||||
)
|
||||
|
||||
def normalize_link(self, link):
|
||||
# 移除协议部分(如 http:// 或 https://)
|
||||
link = re.sub(r"^https?://", "", link)
|
||||
|
||||
# 移除开头的双斜杠
|
||||
link = link.lstrip("/")
|
||||
|
||||
parsed = urllib.parse.urlparse(f"http://{link}")
|
||||
normalized = urllib.parse.urlunparse(
|
||||
("", parsed.netloc, parsed.path, parsed.params, parsed.query, "")
|
||||
)
|
||||
result = normalized.rstrip("/")
|
||||
|
||||
logger.debug(f"Normalized link: {link} -> {result}")
|
||||
return result
|
||||
|
||||
def extract_domain(self, url):
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
domain = parsed.netloc or parsed.path
|
||||
domain = domain.split(':')[0] # Remove port if present
|
||||
parts = domain.split('.')
|
||||
if len(parts) > 2:
|
||||
domain = '.'.join(parts[-2:])
|
||||
return domain.lower()
|
||||
def is_whitelisted(self, link):
|
||||
extracted = tldextract.extract(link)
|
||||
domain = f"{extracted.domain}.{extracted.suffix}"
|
||||
domain = self.extract_domain(link)
|
||||
result = domain in self.whitelist
|
||||
logger.debug(f"Whitelist check for {link}: {'Passed' if result else 'Failed'}")
|
||||
return result
|
||||
@ -84,44 +65,26 @@ class LinkFilter:
|
||||
def add_keyword(self, keyword):
|
||||
if self.link_pattern.match(keyword):
|
||||
keyword = self.normalize_link(keyword)
|
||||
# 确保在这里去掉开头的双斜杠
|
||||
keyword = keyword.lstrip("/")
|
||||
if keyword not in self.keywords:
|
||||
self.keywords.append(keyword)
|
||||
self.save_keywords()
|
||||
self.db.add_keyword(keyword)
|
||||
logger.info(f"New keyword added: {keyword}")
|
||||
self.load_data_from_file() # 重新加载文件
|
||||
self.load_data_from_file()
|
||||
else:
|
||||
logger.debug(f"Keyword already exists: {keyword}")
|
||||
|
||||
def remove_keyword(self, keyword):
|
||||
if keyword in self.keywords:
|
||||
self.keywords.remove(keyword)
|
||||
self.save_keywords()
|
||||
self.load_data_from_file() # 重新加载以确保数据同步
|
||||
self.db.remove_keyword(keyword)
|
||||
self.load_data_from_file()
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_keywords_containing(self, substring):
|
||||
# 记录原始关键词列表的长度
|
||||
original_count = len(self.keywords)
|
||||
|
||||
# 创建一个列表,包含所有需要移除的关键词
|
||||
removed_keywords = [
|
||||
kw for kw in self.keywords if substring.lower() in kw.lower()
|
||||
]
|
||||
|
||||
# 修改关键词列表,仅保留不包含指定子字符串的关键词
|
||||
self.keywords = [
|
||||
kw for kw in self.keywords if substring.lower() not in kw.lower()
|
||||
]
|
||||
|
||||
# 如果有关键词被移除,则保存关键词列表并重新加载数据
|
||||
if removed_keywords:
|
||||
self.save_keywords()
|
||||
self.load_data_from_file()
|
||||
|
||||
# 返回被移除的关键词列表
|
||||
removed_keywords = [kw for kw in self.keywords if substring.lower() in kw.lower()]
|
||||
for keyword in removed_keywords:
|
||||
self.db.remove_keyword(keyword)
|
||||
self.load_data_from_file()
|
||||
return removed_keywords
|
||||
|
||||
def should_filter(self, text):
|
||||
@ -135,7 +98,7 @@ class LinkFilter:
|
||||
new_non_whitelisted_links = []
|
||||
for link in links:
|
||||
normalized_link = self.normalize_link(link)
|
||||
normalized_link = normalized_link.lstrip("/") # 去除开头的双斜杠
|
||||
normalized_link = normalized_link.lstrip("/")
|
||||
if not self.is_whitelisted(normalized_link):
|
||||
logger.debug(f"Link not whitelisted: {normalized_link}")
|
||||
if normalized_link not in self.keywords:
|
||||
@ -169,9 +132,7 @@ class LinkFilter:
|
||||
if self.remove_keyword(keyword):
|
||||
await event.reply(f"关键词 '{keyword}' 已删除。")
|
||||
else:
|
||||
similar_keywords = [
|
||||
k for k in self.keywords if keyword.lower() in k.lower()
|
||||
]
|
||||
similar_keywords = [k for k in self.keywords if keyword.lower() in k.lower()]
|
||||
if similar_keywords:
|
||||
await send_long_message(
|
||||
event,
|
||||
@ -204,8 +165,7 @@ class LinkFilter:
|
||||
elif command == "/addwhite" and args:
|
||||
domain = args[0].lower()
|
||||
if domain not in self.whitelist:
|
||||
self.whitelist.append(domain)
|
||||
self.save_whitelist()
|
||||
self.db.add_whitelist(domain)
|
||||
self.load_data_from_file()
|
||||
await event.reply(f"域名 '{domain}' 已添加到白名单。")
|
||||
else:
|
||||
@ -213,8 +173,7 @@ class LinkFilter:
|
||||
elif command == "/delwhite" and args:
|
||||
domain = args[0].lower()
|
||||
if domain in self.whitelist:
|
||||
self.whitelist.remove(domain)
|
||||
self.save_whitelist()
|
||||
self.db.remove_whitelist(domain)
|
||||
self.load_data_from_file()
|
||||
await event.reply(f"域名 '{domain}' 已从白名单中删除。")
|
||||
else:
|
||||
|
29
src/migrate.py
Normal file
29
src/migrate.py
Normal file
@ -0,0 +1,29 @@
|
||||
import json
|
||||
import os
|
||||
from database import Database
|
||||
|
||||
def migrate_data(json_file, db_file):
    """Migrate keywords and whitelist entries from the legacy JSON file into SQLite.

    :param json_file: path to the old JSON file holding ``"keywords"`` and
        ``"whitelist"`` lists
    :param db_file: path to the target SQLite database file
    """
    # Ensure the target directory exists. Guarding on dirname fixes a crash:
    # os.makedirs("") raises FileNotFoundError when db_file is a bare filename.
    db_dir = os.path.dirname(db_file)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    # Create the database connection (also creates the tables).
    db = Database(db_file)

    # Read the legacy JSON data; utf-8 is specified explicitly so the
    # migration does not depend on the platform's default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    keywords = data.get('keywords', [])
    whitelist = data.get('whitelist', [])

    # Migrate keywords.
    for keyword in keywords:
        db.add_keyword(keyword)

    # Migrate whitelist domains.
    for domain in whitelist:
        db.add_whitelist(domain)

    print(f"迁移完成。关键词:{len(keywords)}个,白名单:{len(whitelist)}个")


if __name__ == "__main__":
    json_file = 'keywords.json'  # path of the legacy JSON file
    db_file = os.path.join('data', 'q58.db')  # path of the new database file
    migrate_data(json_file, db_file)
|
Loading…
x
Reference in New Issue
Block a user