generated from daniil-berg/boilerplate-py
docstrings and full test coverage for server
module; small adjustments
This commit is contained in:
parent
be03097bf4
commit
3fb451a00e
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = asyncio-taskpool
|
||||
version = 0.3.0
|
||||
version = 0.3.1
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Dynamically manage pools of asyncio tasks
|
||||
|
@ -31,74 +31,121 @@ from typing import Optional, Union
|
||||
from .client import ControlClient, UnixControlClient
|
||||
from .pool import TaskPool, SimpleTaskPool
|
||||
from .session import ControlSession
|
||||
from .types import ConnectedCallbackT
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
||||
"""
|
||||
Abstract base class for a task pool control server.
|
||||
|
||||
This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
|
||||
connecting to it. The entire interface is defined within that session class.
|
||||
"""
|
||||
_client_class = ControlClient
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def client_class_name(cls) -> str:
|
||||
"""Returns the name of the control client class matching the server class."""
|
||||
return cls._client_class.__name__
|
||||
|
||||
@abstractmethod
|
||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||
"""
|
||||
Initializes, starts, and returns an async server instance (Unix or TCP type).
|
||||
|
||||
Args:
|
||||
client_connected_cb:
|
||||
The callback for when a client connects to the server (as per `asyncio.start_server` or
|
||||
`asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
|
||||
**kwargs (optional):
|
||||
Keyword arguments to pass into the function that starts the server.
|
||||
|
||||
Returns:
|
||||
The running server object (a type of `asyncio.Server`).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def final_callback(self) -> None:
|
||||
def _final_callback(self) -> None:
|
||||
"""The method to run after the server's `serve_forever` methods ends for whatever reason."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
||||
"""
|
||||
Initializes by merely saving the internal attributes, but without starting the server yet.
|
||||
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
|
||||
tied to one specific task pool.
|
||||
|
||||
Args:
|
||||
pool:
|
||||
An instance of a `BaseTaskPool` subclass to tie the server to.
|
||||
**server_kwargs (optional):
|
||||
Keyword arguments that will be passed into the function that starts the server.
|
||||
"""
|
||||
self._pool: Union[TaskPool, SimpleTaskPool] = pool
|
||||
self._server_kwargs = server_kwargs
|
||||
self._server: Optional[AbstractServer] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__} for {self._pool}"
|
||||
|
||||
@property
|
||||
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
|
||||
"""Read-only property for accessing the task pool instance controlled by the server."""
|
||||
return self._pool
|
||||
|
||||
def is_serving(self) -> bool:
|
||||
"""Wrapper around the `asyncio.Server.is_serving` method."""
|
||||
return self._server.is_serving()
|
||||
|
||||
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||
"""
|
||||
The universal client callback that will be passed into the `_get_server_instance` method.
|
||||
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
|
||||
"""
|
||||
session = ControlSession(self, reader, writer)
|
||||
await session.client_handshake()
|
||||
await session.listen()
|
||||
|
||||
async def _serve_forever(self) -> None:
|
||||
"""
|
||||
To be run as an `asyncio.Task` by the following method.
|
||||
Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
|
||||
method is called, when the former method ends for whatever reason.
|
||||
"""
|
||||
try:
|
||||
async with self._server:
|
||||
await self._server.serve_forever()
|
||||
except CancelledError:
|
||||
log.debug("%s stopped", self.__class__.__name__)
|
||||
finally:
|
||||
self.final_callback()
|
||||
self._final_callback()
|
||||
|
||||
async def serve_forever(self) -> Task:
|
||||
"""
|
||||
This method actually starts the server and begins listening to client connections on the specified interface.
|
||||
It should never block because the serving will be performed in a separate task.
|
||||
"""
|
||||
log.debug("Starting %s...", self.__class__.__name__)
|
||||
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
||||
self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
||||
return create_task(self._serve_forever())
|
||||
|
||||
|
||||
class UnixControlServer(ControlServer):
|
||||
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
|
||||
_client_class = UnixControlClient
|
||||
|
||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||
self._socket_path = Path(server_kwargs.pop('path'))
|
||||
super().__init__(pool, **server_kwargs)
|
||||
|
||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||
log.debug("Opened socket '%s'", str(self._socket_path))
|
||||
return srv
|
||||
return server
|
||||
|
||||
def final_callback(self) -> None:
|
||||
def _final_callback(self) -> None:
|
||||
"""Removes the unix socket on which the server was listening."""
|
||||
self._socket_path.unlink()
|
||||
log.debug("Removed socket '%s'", str(self._socket_path))
|
||||
|
@ -34,4 +34,5 @@ CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||
EndCallbackT = Callable
|
||||
CancelCallbackT = Callable
|
||||
|
||||
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||
|
143
tests/test_server.py
Normal file
143
tests/test_server.py
Normal file
@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import server
|
||||
from asyncio_taskpool.client import ControlClient, UnixControlClient
|
||||
|
||||
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = server.log.level
|
||||
server.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
server.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
|
||||
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||
self.mock_pool = MagicMock()
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.abstract_patcher.stop()
|
||||
|
||||
def test_client_class_name(self):
|
||||
self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name)
|
||||
|
||||
async def test_abstract(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
args = [AsyncMock()]
|
||||
await self.server._get_server_instance(*args)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.server._final_callback()
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.mock_pool, self.server._pool)
|
||||
self.assertEqual(self.kwargs, self.server._server_kwargs)
|
||||
self.assertIsNone(self.server._server)
|
||||
|
||||
def test_pool(self):
|
||||
self.assertEqual(self.mock_pool, self.server.pool)
|
||||
|
||||
def test_is_serving(self):
|
||||
self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR))
|
||||
self.assertEqual(FOO + BAR, self.server.is_serving())
|
||||
|
||||
@patch.object(server, 'ControlSession')
|
||||
async def test__client_connected_cb(self, mock_client_session_cls: MagicMock):
|
||||
mock_client_handshake, mock_listen = AsyncMock(), AsyncMock()
|
||||
mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen)
|
||||
mock_reader, mock_writer = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer))
|
||||
mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer)
|
||||
mock_client_handshake.assert_awaited_once_with()
|
||||
mock_listen.assert_awaited_once_with()
|
||||
|
||||
@patch.object(server.ControlServer, '_final_callback')
|
||||
async def test__serve_forever(self, mock__final_callback: MagicMock):
|
||||
mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError)
|
||||
self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever)
|
||||
with self.assertLogs(server.log, logging.DEBUG):
|
||||
self.assertIsNone(await self.server._serve_forever())
|
||||
mock_aenter.assert_awaited_once_with()
|
||||
mock_serve_forever.assert_awaited_once_with()
|
||||
mock__final_callback.assert_called_once_with()
|
||||
|
||||
mock_aenter.reset_mock()
|
||||
mock_serve_forever.reset_mock(side_effect=True)
|
||||
mock__final_callback.reset_mock()
|
||||
|
||||
self.assertIsNone(await self.server._serve_forever())
|
||||
mock_aenter.assert_awaited_once_with()
|
||||
mock_serve_forever.assert_awaited_once_with()
|
||||
mock__final_callback.assert_called_once_with()
|
||||
|
||||
@patch.object(server, 'create_task')
|
||||
@patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock())
|
||||
@patch.object(server.ControlServer, '_get_server_instance')
|
||||
async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock,
|
||||
mock_create_task: MagicMock):
|
||||
mock__serve_forever.return_value = mock_awaitable = 'some_coroutine'
|
||||
mock_create_task.return_value = expected_output = 12345
|
||||
output = await self.server.serve_forever()
|
||||
self.assertEqual(expected_output, output)
|
||||
mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs)
|
||||
mock__serve_forever.assert_called_once_with()
|
||||
mock_create_task.assert_called_once_with(mock_awaitable)
|
||||
|
||||
|
||||
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = server.log.level
|
||||
server.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
server.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
|
||||
self.mock_base_init = self.base_init_patcher.start()
|
||||
self.mock_pool = MagicMock()
|
||||
self.path = '/tmp/asyncio_taskpool'
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.base_init_patcher.stop()
|
||||
|
||||
def test__client_class(self):
|
||||
self.assertEqual(UnixControlClient, self.server._client_class)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(Path(self.path), self.server._socket_path)
|
||||
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
|
||||
|
||||
@patch.object(server, 'start_unix_server')
|
||||
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
|
||||
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
|
||||
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
|
||||
args = [mock_callback]
|
||||
output = await self.server._get_server_instance(*args, **mock_kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs)
|
||||
|
||||
def test__final_callback(self):
|
||||
self.server._socket_path = MagicMock()
|
||||
self.assertIsNone(self.server._final_callback())
|
||||
self.server._socket_path.unlink.assert_called_once_with()
|
Loading…
Reference in New Issue
Block a user