Compare commits

...

2 Commits

8 changed files with 339 additions and 137 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.3.0 version = 0.3.2
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -26,8 +26,8 @@ from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path from pathlib import Path
from asyncio_taskpool import constants from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from asyncio_taskpool.types import ClientConnT from .types import ClientConnT
class ControlClient(ABC): class ControlClient(ABC):
@ -38,7 +38,7 @@ class ControlClient(ABC):
@staticmethod @staticmethod
def client_info() -> dict: def client_info() -> dict:
return {'width': shutil.get_terminal_size().columns} return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
def __init__(self, **conn_kwargs) -> None: def __init__(self, **conn_kwargs) -> None:
self._conn_kwargs = conn_kwargs self._conn_kwargs = conn_kwargs
@ -48,17 +48,17 @@ class ControlClient(ABC):
self._connected = True self._connected = True
writer.write(json.dumps(self.client_info()).encode()) writer.write(json.dumps(self.client_info()).encode())
await writer.drain() await writer.drain()
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
try: try:
msg = input("> ").strip().lower() msg = input("> ").strip().lower()
except EOFError: except EOFError:
msg = constants.CLIENT_EXIT msg = CLIENT_EXIT
except KeyboardInterrupt: except KeyboardInterrupt:
print() print()
return return
if msg == constants.CLIENT_EXIT: if msg == CLIENT_EXIT:
writer.close() writer.close()
self._connected = False self._connected = False
return return
@ -69,7 +69,7 @@ class ControlClient(ABC):
self._connected = False self._connected = False
print(e, file=sys.stderr) print(e, file=sys.stderr)
return return
print((await reader.read(constants.MSG_BYTES)).decode()) print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self): async def start(self):
reader, writer = await self.open_connection(**self._conn_kwargs) reader, writer = await self.open_connection(**self._conn_kwargs)

View File

@ -20,13 +20,25 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024000
CMD = 'command'
CMD_NAME = 'name'
CMD_POOL_SIZE = 'pool-size'
CMD_NUM_RUNNING = 'num-running'
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop-all'
CMD_FUNC_NAME = 'func-name'
CLIENT_EXIT = 'exit' CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
SESSION_PARSER_WRITER = 'session_writer'
class CLIENT_INFO:
__slots__ = ()
TERMINAL_WIDTH = 'terminal_width'
class CMD:
__slots__ = ()
CMD = 'command'
NAME = 'name'
POOL_SIZE = 'pool-size'
NUM_RUNNING = 'num-running'
START = 'start'
STOP = 'stop'
STOP_ALL = 'stop-all'
FUNC_NAME = 'func-name'

View File

@ -31,74 +31,121 @@ from typing import Optional, Union
from .client import ControlClient, UnixControlClient from .client import ControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession from .session import ControlSession
from .types import ConnectedCallbackT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
"""
Abstract base class for a task pool control server.
This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
connecting to it. The entire interface is defined within that session class.
"""
_client_class = ControlClient _client_class = ControlClient
@classmethod @classmethod
@property @property
def client_class_name(cls) -> str: def client_class_name(cls) -> str:
"""Returns the name of the control client class matching the server class."""
return cls._client_class.__name__ return cls._client_class.__name__
@abstractmethod @abstractmethod
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
"""
Initializes, starts, and returns an async server instance (Unix or TCP type).
Args:
client_connected_cb:
The callback for when a client connects to the server (as per `asyncio.start_server` or
`asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
**kwargs (optional):
Keyword arguments to pass into the function that starts the server.
Returns:
The running server object (a type of `asyncio.Server`).
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def final_callback(self) -> None: def _final_callback(self) -> None:
"""The method to run after the server's `serve_forever` methods ends for whatever reason."""
raise NotImplementedError raise NotImplementedError
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None: def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
"""
Initializes by merely saving the internal attributes, but without starting the server yet.
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
tied to one specific task pool.
Args:
pool:
An instance of a `BaseTaskPool` subclass to tie the server to.
**server_kwargs (optional):
Keyword arguments that will be passed into the function that starts the server.
"""
self._pool: Union[TaskPool, SimpleTaskPool] = pool self._pool: Union[TaskPool, SimpleTaskPool] = pool
self._server_kwargs = server_kwargs self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None self._server: Optional[AbstractServer] = None
def __str__(self) -> str:
return f"{self.__class__.__name__} for {self._pool}"
@property @property
def pool(self) -> Union[TaskPool, SimpleTaskPool]: def pool(self) -> Union[TaskPool, SimpleTaskPool]:
"""Read-only property for accessing the task pool instance controlled by the server."""
return self._pool return self._pool
def is_serving(self) -> bool: def is_serving(self) -> bool:
"""Wrapper around the `asyncio.Server.is_serving` method."""
return self._server.is_serving() return self._server.is_serving()
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
The universal client callback that will be passed into the `_get_server_instance` method.
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
"""
session = ControlSession(self, reader, writer) session = ControlSession(self, reader, writer)
await session.client_handshake() await session.client_handshake()
await session.listen() await session.listen()
async def _serve_forever(self) -> None: async def _serve_forever(self) -> None:
"""
To be run as an `asyncio.Task` by the following method.
Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
method is called, when the former method ends for whatever reason.
"""
try: try:
async with self._server: async with self._server:
await self._server.serve_forever() await self._server.serve_forever()
except CancelledError: except CancelledError:
log.debug("%s stopped", self.__class__.__name__) log.debug("%s stopped", self.__class__.__name__)
finally: finally:
self.final_callback() self._final_callback()
async def serve_forever(self) -> Task: async def serve_forever(self) -> Task:
"""
This method actually starts the server and begins listening to client connections on the specified interface.
It should never block because the serving will be performed in a separate task.
"""
log.debug("Starting %s...", self.__class__.__name__) log.debug("Starting %s...", self.__class__.__name__)
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs) self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
return create_task(self._serve_forever()) return create_task(self._serve_forever())
class UnixControlServer(ControlServer): class UnixControlServer(ControlServer):
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient _client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._socket_path = Path(server_kwargs.pop('path')) self._socket_path = Path(server_kwargs.pop('path'))
super().__init__(pool, **server_kwargs) super().__init__(pool, **server_kwargs)
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
log.debug("Opened socket '%s'", str(self._socket_path)) log.debug("Opened socket '%s'", str(self._socket_path))
return srv return server
def final_callback(self) -> None: def _final_callback(self) -> None:
"""Removes the unix socket on which the server was listening."""
self._socket_path.unlink() self._socket_path.unlink()
log.debug("Removed socket '%s'", str(self._socket_path)) log.debug("Removed socket '%s'", str(self._socket_path))

View File

@ -1,13 +1,14 @@
import logging import logging
import json import json
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace from argparse import ArgumentError, HelpFormatter, Namespace
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Type, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from . import constants from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested from .exceptions import HelpRequested
from .helpers import get_first_doc_line, return_or_exception, tasks_str from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import TaskPool, SimpleTaskPool from .pool import TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM
if TYPE_CHECKING: if TYPE_CHECKING:
from .server import ControlServer from .server import ControlServer
@ -15,47 +16,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
NUM = 'num'
WIDTH = 'width'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
class ClientHelpFormatter(HelpFormatter):
def __init__(self, *args, **kwargs) -> None:
kwargs[WIDTH] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
class ControlSession: class ControlSession:
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
@ -65,86 +25,65 @@ class ControlSession:
self._reader: StreamReader = reader self._reader: StreamReader = reader
self._writer: StreamWriter = writer self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None self._parser: Optional[CommandParser] = None
self._subparsers = None
def _add_base_commands(self): def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD) **kwargs) -> CommandParser:
subparsers.add_parser( if prog is None:
constants.CMD_NAME, prog = name
prog=constants.CMD_NAME, kwargs.setdefault('help', short_help or long_help)
help=get_first_doc_line(self._pool.__class__.__str__), kwargs.setdefault('description', long_help or short_help)
parent=self._parser, return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
)
subparser_pool_size = subparsers.add_parser( def _add_base_commands(self) -> None:
constants.CMD_POOL_SIZE, self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
prog=constants.CMD_POOL_SIZE, self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
help="Get/set the maximum number of tasks in the pool", self._add_parser_command(
parent=self._parser, CMD.POOL_SIZE,
) short_help="Get/set the maximum number of tasks in the pool.",
subparser_pool_size.add_argument( formatter_class=HelpFormatter
NUM, ).add_optional_num_argument(
nargs='?', default=None,
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
) )
subparsers.add_parser( self._add_parser_command(
constants.CMD_NUM_RUNNING, CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)
help=get_first_doc_line(self._pool.__class__.num_running.fget),
parent=self._parser,
) )
return subparsers
def _add_simple_commands(self): def _add_simple_commands(self) -> None:
subparsers = self._add_base_commands() self._add_parser_command(
subparser = subparsers.add_parser( CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
constants.CMD_START, ).add_optional_num_argument(
prog=constants.CMD_START, help="Number of tasks to start."
help=get_first_doc_line(self._pool.__class__.start),
parent=self._parser,
) )
subparser.add_argument( self._add_parser_command(
NUM, CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
nargs='?', ).add_optional_num_argument(
type=int, help="Number of tasks to stop."
default=1,
help="Number of tasks to start. Defaults to 1."
) )
subparser = subparsers.add_parser( self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
constants.CMD_STOP, self._add_parser_command(
prog=constants.CMD_STOP, CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)
help=get_first_doc_line(self._pool.__class__.stop),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to stop. Defaults to 1."
)
subparsers.add_parser(
constants.CMD_STOP_ALL,
prog=constants.CMD_STOP_ALL,
help=get_first_doc_line(self._pool.__class__.stop_all),
parent=self._parser,
)
subparsers.add_parser(
constants.CMD_FUNC_NAME,
prog=constants.CMD_FUNC_NAME,
help=get_first_doc_line(self._pool.__class__.func_name.fget),
parent=self._parser,
) )
def _init_parser(self, client_terminal_width: int) -> None: def _init_parser(self, client_terminal_width: int) -> None:
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width) parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
}
self._parser = CommandParser(**parser_kwargs)
self._add_base_commands()
if isinstance(self._pool, TaskPool): if isinstance(self._pool, TaskPool):
pass # TODO pass # TODO
elif isinstance(self._pool, SimpleTaskPool): elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands() self._add_simple_commands()
async def client_handshake(self) -> None: async def client_handshake(self) -> None:
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip()) client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[WIDTH]) self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
self._writer.write(str(self._pool).encode()) self._writer.write(str(self._pool).encode())
await self._writer.drain() await self._writer.drain()
@ -163,7 +102,7 @@ class ControlSession:
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool) await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else: else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num) log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num)) await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_num_running(self, **_kwargs) -> None: async def _cmd_num_running(self, **_kwargs) -> None:
log.debug("%s requests number of running tasks", self._client_class_name) log.debug("%s requests number of running tasks", self._client_class_name)
@ -189,7 +128,7 @@ class ControlSession:
async def _execute_command(self, args: Namespace) -> None: async def _execute_command(self, args: Namespace) -> None:
args = vars(args) args = vars(args)
cmd: str = args.pop(constants.CMD, None) cmd: str = args.pop(CMD.CMD, None)
if cmd is not None: if cmd is not None:
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}') method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
await method(**args) await method(**args)
@ -210,7 +149,7 @@ class ControlSession:
async def listen(self) -> None: async def listen(self) -> None:
while self._control_server.is_serving(): while self._control_server.is_serving():
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip() msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg: if not msg:
log.debug("%s disconnected", self._client_class_name) log.debug("%s disconnected", self._client_class_name)
break break

View File

@ -0,0 +1,60 @@
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter
from typing import Type, TypeVar
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO
from .exceptions import HelpRequested
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
FORMATTER_CLASS = 'formatter_class'
NUM = 'num'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER)
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
if not name_or_flags:
name_or_flags = (NUM, )
kwargs.setdefault('nargs', '?')
kwargs.setdefault('default', 1)
kwargs.setdefault('type', int)
return self.add_argument(*name_or_flags, **kwargs)

View File

@ -34,4 +34,5 @@ CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable EndCallbackT = Callable
CancelCallbackT = Callable CancelCallbackT = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

143
tests/test_server.py Normal file
View File

@ -0,0 +1,143 @@
import asyncio
import logging
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import server
from asyncio_taskpool.client import ControlClient, UnixControlClient
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = server.log.level
server.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_pool = MagicMock()
self.kwargs = {FOO: 123, BAR: 456}
self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs)
def tearDown(self) -> None:
self.abstract_patcher.stop()
def test_client_class_name(self):
self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name)
async def test_abstract(self):
with self.assertRaises(NotImplementedError):
args = [AsyncMock()]
await self.server._get_server_instance(*args)
with self.assertRaises(NotImplementedError):
self.server._final_callback()
def test_init(self):
self.assertEqual(self.mock_pool, self.server._pool)
self.assertEqual(self.kwargs, self.server._server_kwargs)
self.assertIsNone(self.server._server)
def test_pool(self):
self.assertEqual(self.mock_pool, self.server.pool)
def test_is_serving(self):
self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR))
self.assertEqual(FOO + BAR, self.server.is_serving())
@patch.object(server, 'ControlSession')
async def test__client_connected_cb(self, mock_client_session_cls: MagicMock):
mock_client_handshake, mock_listen = AsyncMock(), AsyncMock()
mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen)
mock_reader, mock_writer = MagicMock(), MagicMock()
self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer))
mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer)
mock_client_handshake.assert_awaited_once_with()
mock_listen.assert_awaited_once_with()
@patch.object(server.ControlServer, '_final_callback')
async def test__serve_forever(self, mock__final_callback: MagicMock):
mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError)
self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever)
with self.assertLogs(server.log, logging.DEBUG):
self.assertIsNone(await self.server._serve_forever())
mock_aenter.assert_awaited_once_with()
mock_serve_forever.assert_awaited_once_with()
mock__final_callback.assert_called_once_with()
mock_aenter.reset_mock()
mock_serve_forever.reset_mock(side_effect=True)
mock__final_callback.reset_mock()
self.assertIsNone(await self.server._serve_forever())
mock_aenter.assert_awaited_once_with()
mock_serve_forever.assert_awaited_once_with()
mock__final_callback.assert_called_once_with()
@patch.object(server, 'create_task')
@patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock())
@patch.object(server.ControlServer, '_get_server_instance')
async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock,
mock_create_task: MagicMock):
mock__serve_forever.return_value = mock_awaitable = 'some_coroutine'
mock_create_task.return_value = expected_output = 12345
output = await self.server.serve_forever()
self.assertEqual(expected_output, output)
mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs)
mock__serve_forever.assert_called_once_with()
mock_create_task.assert_called_once_with(mock_awaitable)
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = server.log.level
server.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.mock_pool = MagicMock()
self.path = '/tmp/asyncio_taskpool'
self.kwargs = {FOO: 123, BAR: 456}
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test__client_class(self):
self.assertEqual(UnixControlClient, self.server._client_class)
def test_init(self):
self.assertEqual(Path(self.path), self.server._socket_path)
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
@patch.object(server, 'start_unix_server')
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
args = [mock_callback]
output = await self.server._get_server_instance(*args, **mock_kwargs)
self.assertEqual(expected_output, output)
mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs)
def test__final_callback(self):
self.server._socket_path = MagicMock()
self.assertIsNone(self.server._final_callback())
self.server._socket_path.unlink.assert_called_once_with()