diff --git a/src/asyncio_taskpool/control/session.py b/src/asyncio_taskpool/control/session.py index 457e54c..ca91cac 100644 --- a/src/asyncio_taskpool/control/session.py +++ b/src/asyncio_taskpool/control/session.py @@ -27,7 +27,7 @@ from inspect import isfunction, signature from typing import Callable, Optional, Union, TYPE_CHECKING from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER -from ..exceptions import HelpRequested +from ..exceptions import CommandError, HelpRequested from ..helpers import return_or_exception from ..pool import TaskPool, SimpleTaskPool from .parser import ControlParser @@ -163,6 +163,8 @@ class ControlSession: await self._exec_method_and_respond(command, **kwargs) elif isinstance(command, property): await self._exec_property_and_respond(command, **kwargs) + else: + self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode()) async def listen(self) -> None: """ diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py index e2715a7..c6858f8 100644 --- a/src/asyncio_taskpool/exceptions.py +++ b/src/asyncio_taskpool/exceptions.py @@ -65,3 +65,7 @@ class ServerException(Exception): class HelpRequested(ServerException): pass + + +class CommandError(ServerException): + pass diff --git a/tests/test_control/test_client.py b/tests/test_control/test_client.py index 683e5f3..fdaca43 100644 --- a/tests/test_control/test_client.py +++ b/tests/test_control/test_client.py @@ -172,6 +172,43 @@ class ControlClientTestCase(IsolatedAsyncioTestCase): self.mock_print.assert_called_once_with("Disconnected from control server.") +class TCPControlClientTestCase(IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.base_init_patcher = patch.object(client.ControlClient, '__init__') + self.mock_base_init = self.base_init_patcher.start() + self.host, self.port = 'localhost', 12345 + self.kwargs = {FOO: 123, BAR: 456} + self.client = client.TCPControlClient(host=self.host, port=self.port, **self.kwargs) + + def tearDown(self) -> None: + self.base_init_patcher.stop() + + def test_init(self): + self.assertEqual(self.host, self.client._host) + self.assertEqual(self.port, self.client._port) + self.mock_base_init.assert_called_once_with(**self.kwargs) + + @patch.object(client, 'print') + @patch.object(client, 'open_connection') + async def test__open_connection(self, mock_open_connection: AsyncMock, mock_print: MagicMock): + mock_open_connection.return_value = expected_output = 'something' + kwargs = {'a': 1, 'b': 2} + output = await self.client._open_connection(**kwargs) + self.assertEqual(expected_output, output) + mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs) + mock_print.assert_not_called() + + mock_open_connection.reset_mock() + + mock_open_connection.side_effect = e = ConnectionError() + output1, output2 = await self.client._open_connection(**kwargs) + self.assertIsNone(output1) + self.assertIsNone(output2) + mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs) + mock_print.assert_called_once_with(str(e), file=sys.stderr) + + @skipIf(os.name == 'nt', "No Unix sockets on Windows :(") class UnixControlClientTestCase(IsolatedAsyncioTestCase): diff --git a/tests/test_control/test_parser.py b/tests/test_control/test_parser.py index ab60db8..2cd0c57 100644 --- a/tests/test_control/test_parser.py +++ b/tests/test_control/test_parser.py @@ -157,6 +157,14 @@ class ControlServerTestCase(TestCase): mock_add_property_command.assert_called_once_with(FooBar.prop, FooBar.__name__, **common_kwargs) mock_set_defaults.assert_has_calls([call(**{x: FooBar.method}), call(**{x: FooBar.prop})]) + @patch.object(parser.ArgumentParser, 'add_subparsers') + def test_add_subparsers(self, mock_base_add_subparsers: MagicMock): + args, kwargs = [1, 2, 42], {FOO: 123, BAR: 456} + mock_base_add_subparsers.return_value = mock_action = MagicMock() + output = self.parser.add_subparsers(*args, **kwargs) + self.assertEqual(mock_action, output) + mock_base_add_subparsers.assert_called_once_with(*args, **kwargs) + def test__print_message(self): self.stream_writer.write = MagicMock() self.assertIsNone(self.parser._print_message('')) diff --git a/tests/test_control/test_server.py b/tests/test_control/test_server.py index b61d78d..3bd4b44 100644 --- a/tests/test_control/test_server.py +++ b/tests/test_control/test_server.py @@ -27,7 +27,7 @@ from unittest import IsolatedAsyncioTestCase, skipIf from unittest.mock import AsyncMock, MagicMock, patch from asyncio_taskpool.control import server -from asyncio_taskpool.control.client import ControlClient, UnixControlClient +from asyncio_taskpool.control.client import ControlClient, TCPControlClient, UnixControlClient FOO, BAR = 'foo', 'bar' @@ -120,6 +120,50 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): mock_create_task.assert_called_once_with(mock_awaitable) +class TCPControlServerTestCase(IsolatedAsyncioTestCase): + log_lvl: int + + @classmethod + def setUpClass(cls) -> None: + cls.log_lvl = server.log.level + server.log.setLevel(999) + + @classmethod + def tearDownClass(cls) -> None: + server.log.setLevel(cls.log_lvl) + + def setUp(self) -> None: + self.base_init_patcher = patch.object(server.ControlServer, '__init__') + self.mock_base_init = self.base_init_patcher.start() + self.mock_pool = MagicMock() + self.host, self.port = 'localhost', 12345 + self.kwargs = {FOO: 123, BAR: 456} + self.server = server.TCPControlServer(pool=self.mock_pool, host=self.host, port=self.port, **self.kwargs) + + def tearDown(self) -> None: + self.base_init_patcher.stop() + + def test__client_class(self): + self.assertEqual(TCPControlClient, self.server._client_class) + + def test_init(self): + self.assertEqual(self.host, self.server._host) + self.assertEqual(self.port, self.server._port) + self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs) + + @patch.object(server, 'start_server') + async def test__get_server_instance(self, mock_start_server: AsyncMock): + mock_start_server.return_value = expected_output = 'totally_a_server' + mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2} + args = [mock_callback] + output = await self.server._get_server_instance(*args, **mock_kwargs) + self.assertEqual(expected_output, output) + mock_start_server.assert_called_once_with(mock_callback, self.host, self.port, **mock_kwargs) + + def test__final_callback(self): + self.assertIsNone(self.server._final_callback()) + + @skipIf(os.name == 'nt', "No Unix sockets on Windows :(") class UnixControlServerTestCase(IsolatedAsyncioTestCase): log_lvl: int diff --git a/tests/test_control/test_session.py b/tests/test_control/test_session.py index fd393a0..00bff77 100644 --- a/tests/test_control/test_session.py +++ b/tests/test_control/test_session.py @@ -142,9 +142,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): mock__exec_method_and_respond.reset_mock() mock_parse_args.reset_mock() - mock_parse_args = MagicMock(return_value=Namespace(**{CMD: prop}, **kwargs)) - self.session._parser = MagicMock(parse_args=mock_parse_args) - self.mock_writer.write = MagicMock() + mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs) self.assertIsNone(await self.session._parse_command(msg)) mock_parse_args.assert_called_once_with(msg.split(' ')) self.mock_writer.write.assert_not_called() @@ -154,6 +152,21 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): mock__exec_property_and_respond.reset_mock() mock_parse_args.reset_mock() + bad_command = 'definitely not a function or property' + mock_parse_args.return_value = Namespace(**{CMD: bad_command}, **kwargs) + with patch.object(session, 'CommandError') as cmd_err_cls: + cmd_err_cls.return_value = exc = MagicMock() + self.assertIsNone(await self.session._parse_command(msg)) + cmd_err_cls.assert_called_once_with(f"Unknown command object: {bad_command}") + mock_parse_args.assert_called_once_with(msg.split(' ')) + mock__exec_method_and_respond.assert_not_called() + mock__exec_property_and_respond.assert_not_called() + self.mock_writer.write.assert_called_once_with(str(exc).encode()) + + mock__exec_property_and_respond.reset_mock() + mock_parse_args.reset_mock() + self.mock_writer.write.reset_mock() + mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") self.assertIsNone(await self.session._parse_command(msg)) mock_parse_args.assert_called_once_with(msg.split(' '))