diff --git a/setup.cfg b/setup.cfg index 1bdeaca..8dae159 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.3.0 +version = 0.3.1 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index 3a3164e..e0b9005 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -31,74 +31,121 @@ from typing import Optional, Union from .client import ControlClient, UnixControlClient from .pool import TaskPool, SimpleTaskPool from .session import ControlSession +from .types import ConnectedCallbackT log = logging.getLogger(__name__) class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool + """ + Abstract base class for a task pool control server. + + This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client + connecting to it. The entire interface is defined within that session class. + """ _client_class = ControlClient @classmethod @property def client_class_name(cls) -> str: + """Returns the name of the control client class matching the server class.""" return cls._client_class.__name__ @abstractmethod - async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: + async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: + """ + Initializes, starts, and returns an async server instance (Unix or TCP type). + + Args: + client_connected_cb: + The callback for when a client connects to the server (as per `asyncio.start_server` or + `asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method. + **kwargs (optional): + Keyword arguments to pass into the function that starts the server. + + Returns: + The running server object (a type of `asyncio.Server`). + """ raise NotImplementedError @abstractmethod - def final_callback(self) -> None: + def _final_callback(self) -> None: + """The method to run after the server's `serve_forever` methods ends for whatever reason.""" raise NotImplementedError def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None: + """ + Initializes by merely saving the internal attributes, but without starting the server yet. + The task pool must be passed here and can not be set/changed afterwards. This means a control server is always + tied to one specific task pool. + + Args: + pool: + An instance of a `BaseTaskPool` subclass to tie the server to. + **server_kwargs (optional): + Keyword arguments that will be passed into the function that starts the server. + """ self._pool: Union[TaskPool, SimpleTaskPool] = pool self._server_kwargs = server_kwargs self._server: Optional[AbstractServer] = None - def __str__(self) -> str: - return f"{self.__class__.__name__} for {self._pool}" - @property def pool(self) -> Union[TaskPool, SimpleTaskPool]: + """Read-only property for accessing the task pool instance controlled by the server.""" return self._pool def is_serving(self) -> bool: + """Wrapper around the `asyncio.Server.is_serving` method.""" return self._server.is_serving() async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: + """ + The universal client callback that will be passed into the `_get_server_instance` method. + Instantiates a control session, performs the client handshake, and enters the session's `listen` loop. + """ session = ControlSession(self, reader, writer) await session.client_handshake() await session.listen() async def _serve_forever(self) -> None: + """ + To be run as an `asyncio.Task` by the following method. + Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback` + method is called, when the former method ends for whatever reason. + """ try: async with self._server: await self._server.serve_forever() except CancelledError: log.debug("%s stopped", self.__class__.__name__) finally: - self.final_callback() + self._final_callback() async def serve_forever(self) -> Task: + """ + This method actually starts the server and begins listening to client connections on the specified interface. + It should never block because the serving will be performed in a separate task. + """ log.debug("Starting %s...", self.__class__.__name__) - self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs) + self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs) return create_task(self._serve_forever()) class UnixControlServer(ControlServer): + """Task pool control server class that exposes a unix socket for control clients to connect to.""" _client_class = UnixControlClient def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: self._socket_path = Path(server_kwargs.pop('path')) super().__init__(pool, **server_kwargs) - async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: - srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) + async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: + server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) log.debug("Opened socket '%s'", str(self._socket_path)) - return srv + return server - def final_callback(self) -> None: + def _final_callback(self) -> None: + """Removes the unix socket on which the server was listening.""" self._socket_path.unlink() log.debug("Removed socket '%s'", str(self._socket_path)) diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index bca1c33..24f373f 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -34,4 +34,5 @@ CoroutineFunc = Callable[[...], Awaitable[Any]] EndCallbackT = Callable CancelCallbackT = Callable +ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..1e5030b --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,143 @@ +import asyncio +import logging +from pathlib import Path +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from asyncio_taskpool import server +from asyncio_taskpool.client import ControlClient, UnixControlClient + + +FOO, BAR = 'foo', 'bar' + + +class ControlServerTestCase(IsolatedAsyncioTestCase): + log_lvl: int + + @classmethod + def setUpClass(cls) -> None: + cls.log_lvl = server.log.level + server.log.setLevel(999) + + @classmethod + def tearDownClass(cls) -> None: + server.log.setLevel(cls.log_lvl) + + def setUp(self) -> None: + self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set()) + self.mock_abstract_methods = self.abstract_patcher.start() + self.mock_pool = MagicMock() + self.kwargs = {FOO: 123, BAR: 456} + self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs) + + def tearDown(self) -> None: + self.abstract_patcher.stop() + + def test_client_class_name(self): + self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name) + + async def test_abstract(self): + with self.assertRaises(NotImplementedError): + args = [AsyncMock()] + await self.server._get_server_instance(*args) + with self.assertRaises(NotImplementedError): + self.server._final_callback() + + def test_init(self): + self.assertEqual(self.mock_pool, self.server._pool) + self.assertEqual(self.kwargs, self.server._server_kwargs) + self.assertIsNone(self.server._server) + + def test_pool(self): + self.assertEqual(self.mock_pool, self.server.pool) + + def test_is_serving(self): + self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR)) + self.assertEqual(FOO + BAR, self.server.is_serving()) + + @patch.object(server, 'ControlSession') + async def test__client_connected_cb(self, mock_client_session_cls: MagicMock): + mock_client_handshake, mock_listen = AsyncMock(), AsyncMock() + mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen) + mock_reader, mock_writer = MagicMock(), MagicMock() + self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer)) + mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer) + mock_client_handshake.assert_awaited_once_with() + mock_listen.assert_awaited_once_with() + + @patch.object(server.ControlServer, '_final_callback') + async def test__serve_forever(self, mock__final_callback: MagicMock): + mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError) + self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever) + with self.assertLogs(server.log, logging.DEBUG): + self.assertIsNone(await self.server._serve_forever()) + mock_aenter.assert_awaited_once_with() + mock_serve_forever.assert_awaited_once_with() + mock__final_callback.assert_called_once_with() + + mock_aenter.reset_mock() + mock_serve_forever.reset_mock(side_effect=True) + mock__final_callback.reset_mock() + + self.assertIsNone(await self.server._serve_forever()) + mock_aenter.assert_awaited_once_with() + mock_serve_forever.assert_awaited_once_with() + mock__final_callback.assert_called_once_with() + + @patch.object(server, 'create_task') + @patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock()) + @patch.object(server.ControlServer, '_get_server_instance') + async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock, + mock_create_task: MagicMock): + mock__serve_forever.return_value = mock_awaitable = 'some_coroutine' + mock_create_task.return_value = expected_output = 12345 + output = await self.server.serve_forever() + self.assertEqual(expected_output, output) + mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs) + mock__serve_forever.assert_called_once_with() + mock_create_task.assert_called_once_with(mock_awaitable) + + +class UnixControlServerTestCase(IsolatedAsyncioTestCase): + log_lvl: int + + @classmethod + def setUpClass(cls) -> None: + cls.log_lvl = server.log.level + server.log.setLevel(999) + + @classmethod + def tearDownClass(cls) -> None: + server.log.setLevel(cls.log_lvl) + + def setUp(self) -> None: + self.base_init_patcher = patch.object(server.ControlServer, '__init__') + self.mock_base_init = self.base_init_patcher.start() + self.mock_pool = MagicMock() + self.path = '/tmp/asyncio_taskpool' + self.kwargs = {FOO: 123, BAR: 456} + self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs) + + def tearDown(self) -> None: + self.base_init_patcher.stop() + + def test__client_class(self): + self.assertEqual(UnixControlClient, self.server._client_class) + + def test_init(self): + self.assertEqual(Path(self.path), self.server._socket_path) + self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs) + + @patch.object(server, 'start_unix_server') + async def test__get_server_instance(self, mock_start_unix_server: AsyncMock): + mock_start_unix_server.return_value = expected_output = 'totally_a_server' + mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2} + args = [mock_callback] + output = await self.server._get_server_instance(*args, **mock_kwargs) + self.assertEqual(expected_output, output) + mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs) + + def test__final_callback(self): + self.server._socket_path = MagicMock() + self.assertIsNone(self.server._final_callback()) + self.server._socket_path.unlink.assert_called_once_with()