diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3585af6..16b4892 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,8 +30,7 @@ jobs: - name: Run tests (if any) run: | - # Add your test commands here - # For example: python -m unittest discover tests + python -m unittest discover tests - name: Set up QEMU uses: docker/setup-qemu-action@v3 diff --git a/tests/test_guard.py b/tests/test_guard.py new file mode 100644 index 0000000..6525b5e --- /dev/null +++ b/tests/test_guard.py @@ -0,0 +1,32 @@ +import unittest +from unittest.mock import AsyncMock, patch +from guard import process_message, command_handler + +class TestGuard(unittest.TestCase): + @patch('guard.link_filter') + async def test_process_message(self, mock_link_filter): + mock_link_filter.should_filter.return_value = (True, []) + event = AsyncMock() + event.is_private = False + event.sender_id = 12345 # 非管理员ID + + await process_message(event, AsyncMock()) + + event.delete.assert_called_once() + event.respond.assert_called_once_with("已撤回该消息。注:包含关键词或重复发送的非白名单链接会被自动撤回。") + + @patch('guard.handle_command') + @patch('guard.link_filter') + async def test_command_handler(self, mock_link_filter, mock_handle_command): + event = AsyncMock() + event.is_private = True + event.sender_id = int(os.environ.get('ADMIN_ID')) + event.raw_text = '/add keyword' + + await command_handler(event, mock_link_filter) + + mock_handle_command.assert_called_once() + mock_link_filter.load_data_from_file.assert_called_once() + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_link_filter.py b/tests/test_link_filter.py new file mode 100644 index 0000000..3a66233 --- /dev/null +++ b/tests/test_link_filter.py @@ -0,0 +1,45 @@ +import unittest +import tempfile +import json +from link_filter import LinkFilter + +class TestLinkFilter(unittest.TestCase): + def setUp(self): + # 创建临时文件作为关键词和白名单文件 + self.keywords_file = tempfile.NamedTemporaryFile(mode='w+', delete=False) + self.whitelist_file = tempfile.NamedTemporaryFile(mode='w+', delete=False) + + # 写入一些初始数据 + json.dump(['example.com'], self.keywords_file) + json.dump(['google.com'], self.whitelist_file) + + self.keywords_file.close() + self.whitelist_file.close() + + self.link_filter = LinkFilter(self.keywords_file.name, self.whitelist_file.name) + + def test_normalize_link(self): + self.assertEqual(self.link_filter.normalize_link('https://www.example.com'), 'www.example.com') + self.assertEqual(self.link_filter.normalize_link('http://example.com'), 'example.com') + + def test_is_whitelisted(self): + self.assertTrue(self.link_filter.is_whitelisted('https://www.google.com')) + self.assertFalse(self.link_filter.is_whitelisted('https://www.example.com')) + + def test_should_filter(self): + should_filter, new_links = self.link_filter.should_filter('Check out https://www.example.com') + self.assertTrue(should_filter) + self.assertEqual(new_links, []) + + should_filter, new_links = self.link_filter.should_filter('Check out https://www.newsite.com') + self.assertFalse(should_filter) + self.assertEqual(new_links, ['www.newsite.com']) + + def tearDown(self): + # 删除临时文件 + import os + os.unlink(self.keywords_file.name) + os.unlink(self.whitelist_file.name) + +if __name__ == '__main__': + unittest.main()