Compare commits

..

6 Commits

12 changed files with 690 additions and 145 deletions

View File

@@ -8,5 +8,8 @@ omit =
fail_under = 100
show_missing = True
skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
omit =
tests/*

View File

@@ -14,7 +14,7 @@ If you need control over a task pool at runtime, you can launch an asynchronous
## Usage
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like:
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
```python
from asyncio_taskpool import SimpleTaskPool
...

View File

@@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.3.1
version = 0.3.3
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@@ -26,8 +26,8 @@ from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from asyncio_taskpool import constants
from asyncio_taskpool.types import ClientConnT
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT
class ControlClient(ABC):
@@ -38,7 +38,7 @@ class ControlClient(ABC):
@staticmethod
def client_info() -> dict:
return {'width': shutil.get_terminal_size().columns}
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
def __init__(self, **conn_kwargs) -> None:
self._conn_kwargs = conn_kwargs
@@ -48,17 +48,17 @@ class ControlClient(ABC):
self._connected = True
writer.write(json.dumps(self.client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
try:
msg = input("> ").strip().lower()
except EOFError:
msg = constants.CLIENT_EXIT
msg = CLIENT_EXIT
except KeyboardInterrupt:
print()
return
if msg == constants.CLIENT_EXIT:
if msg == CLIENT_EXIT:
writer.close()
self._connected = False
return
@@ -69,7 +69,7 @@ class ControlClient(ABC):
self._connected = False
print(e, file=sys.stderr)
return
print((await reader.read(constants.MSG_BYTES)).decode())
print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self):
reader, writer = await self.open_connection(**self._conn_kwargs)

View File

@@ -20,13 +20,25 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024000
CMD = 'command'
CMD_NAME = 'name'
CMD_POOL_SIZE = 'pool-size'
CMD_NUM_RUNNING = 'num-running'
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop-all'
CMD_FUNC_NAME = 'func-name'
CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
SESSION_PARSER_WRITER = 'session_writer'
class CLIENT_INFO:
__slots__ = ()
TERMINAL_WIDTH = 'terminal_width'
class CMD:
__slots__ = ()
CMD = 'command'
NAME = 'name'
POOL_SIZE = 'pool-size'
NUM_RUNNING = 'num-running'
START = 'start'
STOP = 'stop'
STOP_ALL = 'stop-all'
FUNC_NAME = 'func-name'

View File

@@ -55,5 +55,13 @@ class ServerException(Exception):
pass
class UnknownTaskPoolClass(ServerException):
pass
class NotATaskPool(ServerException):
pass
class HelpRequested(ServerException):
pass

View File

@@ -19,7 +19,6 @@ Miscellaneous helper functions.
"""
import re
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from inspect import getdoc
@@ -57,7 +56,7 @@ def tasks_str(num: int) -> str:
def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0]
return getdoc(obj).strip().split("\n", 1)[0].strip()
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:

View File

@@ -1,13 +1,35 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the control session class used by the control server.
"""
import logging
import json
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace
from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
from typing import Callable, Optional, Union, TYPE_CHECKING
from . import constants
from .exceptions import HelpRequested
from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import TaskPool, SimpleTaskPool
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM
if TYPE_CHECKING:
from .server import ControlServer
@@ -15,202 +37,266 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
NUM = 'num'
WIDTH = 'width'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
class ClientHelpFormatter(HelpFormatter):
def __init__(self, *args, **kwargs) -> None:
kwargs[WIDTH] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
class ControlSession:
"""
This class defines the API for controlling a task pool instance from the outside.
The commands received from a connected client are translated into method calls on the task pool instance.
A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
"""
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
"""
Instantiation should happen once a client connection to the control server has already been established.
For more convenient/efficient access, some of the server's properties are saved in separate attributes.
The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.
Args:
server:
The instance of a `ControlServer` subclass starting the session.
reader:
The `asyncio.StreamReader` created when a client connected to the server.
writer:
The `asyncio.StreamWriter` created when a client connected to the server.
"""
self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None
self._subparsers = None
def _add_base_commands(self):
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD)
subparsers.add_parser(
constants.CMD_NAME,
prog=constants.CMD_NAME,
help=get_first_doc_line(self._pool.__class__.__str__),
parent=self._parser,
)
subparser_pool_size = subparsers.add_parser(
constants.CMD_POOL_SIZE,
prog=constants.CMD_POOL_SIZE,
help="Get/set the maximum number of tasks in the pool",
parent=self._parser,
)
subparser_pool_size.add_argument(
NUM,
nargs='?',
def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
**kwargs) -> CommandParser:
"""
Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.
Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.
Args:
name:
The command name; passed directly into the `add_parser` method.
prog (optional):
Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
equal to the `name` argument.
short_help (optional):
Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
`long_help` argument is present; in that case the `long_help` argument is passed as `help`.
long_help (optional):
Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
**kwargs (optional):
Any keyword-arguments to directly pass into the `add_parser` method.
Returns:
An instance of the `CommandParser` class representing the newly added control command.
"""
if prog is None:
prog = name
kwargs.setdefault('help', short_help or long_help)
kwargs.setdefault('description', long_help or short_help)
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
def _add_base_commands(self) -> None:
"""
Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.
These include commands mapping to the following pool methods:
- __str__
- pool_size (get/set property)
- num_running
"""
self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
self._add_command(
CMD.POOL_SIZE,
short_help="Get/set the maximum number of tasks in the pool.",
formatter_class=HelpFormatter
).add_optional_num_argument(
default=None,
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
)
subparsers.add_parser(
constants.CMD_NUM_RUNNING,
help=get_first_doc_line(self._pool.__class__.num_running.fget),
parent=self._parser,
)
return subparsers
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
def _add_simple_commands(self):
subparsers = self._add_base_commands()
subparser = subparsers.add_parser(
constants.CMD_START,
prog=constants.CMD_START,
help=get_first_doc_line(self._pool.__class__.start),
parent=self._parser,
def _add_simple_commands(self) -> None:
"""
Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.
These include commands mapping to the following pool methods:
- start
- stop
- stop_all
- func_name
"""
self._add_command(
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
).add_optional_num_argument(
help="Number of tasks to start."
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to start. Defaults to 1."
)
subparser = subparsers.add_parser(
constants.CMD_STOP,
prog=constants.CMD_STOP,
help=get_first_doc_line(self._pool.__class__.stop),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to stop. Defaults to 1."
)
subparsers.add_parser(
constants.CMD_STOP_ALL,
prog=constants.CMD_STOP_ALL,
help=get_first_doc_line(self._pool.__class__.stop_all),
parent=self._parser,
)
subparsers.add_parser(
constants.CMD_FUNC_NAME,
prog=constants.CMD_FUNC_NAME,
help=get_first_doc_line(self._pool.__class__.func_name.fget),
parent=self._parser,
self._add_command(
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
).add_optional_num_argument(
help="Number of tasks to stop."
)
self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))
def _add_advanced_commands(self) -> None:
"""
Adds the commands that are only supported, if a `TaskPool` object is controlled.
These include commands mapping to the following pool methods:
- ...
"""
raise NotImplementedError
def _init_parser(self, client_terminal_width: int) -> None:
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width)
"""
Initializes and fully configures the `CommandParser` responsible for handling the input.
Depending on what specific task pool class is controlled by the server, different commands are added.
Args:
client_terminal_width:
The number of columns of the client's terminal to be able to nicely format messages from the parser.
"""
parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
}
self._parser = CommandParser(**parser_kwargs)
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
self._add_base_commands()
if isinstance(self._pool, TaskPool):
pass # TODO
self._add_advanced_commands()
elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands()
elif isinstance(self._pool, BaseTaskPool):
raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}")
else:
raise NotATaskPool(f"Not a task pool instance: {self._pool}")
async def client_handshake(self) -> None:
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip())
"""
This method must be invoked before starting any other client interaction.
Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
"""
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[WIDTH])
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
self._writer.write(str(self._pool).encode())
await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
"""
Acts as a wrapper around a call to a specific task pool method.
The method is called and any exception is caught and saved. If there is no output and no exception caught, a
generic confirmation message is sent back to the client. Otherwise the output or a string representation of
the exception caught is sent back.
Args:
func:
Reference to the task pool method.
*args (optional):
Any positional arguments to call the method with.
*+kwargs (optional):
Any keyword-arguments to call the method with.
"""
output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None:
"""Maps to the `__str__` method of any task pool class."""
log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None:
"""Maps to the `pool_size` property of any task pool class."""
num = kwargs.get(NUM)
if num is None:
log.debug("%s requests pool size", self._client_class_name)
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num))
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_num_running(self, **_kwargs) -> None:
"""Maps to the `num_running` property of any task pool class."""
log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_start(self, **kwargs) -> None:
"""Maps to the `start` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None:
"""Maps to the `stop` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None:
"""Maps to the `stop_all` method of the `SimpleTaskPool` class."""
log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None:
"""Maps to the `func_name` method of the `SimpleTaskPool` class."""
log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
async def _execute_command(self, args: Namespace) -> None:
args = vars(args)
cmd: str = args.pop(constants.CMD, None)
if cmd is not None:
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
await method(**args)
async def _execute_command(self, **kwargs) -> None:
"""
Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.
Args:
**kwargs:
Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
simply passed into the method determined from the command name.
"""
method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
await method(**kwargs)
async def _parse_command(self, msg: str) -> None:
"""
Takes a message from the client and attempts to parse it.
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
`CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
dictionary of keyword-arguments returned by the `CommandParser` passed into it.
Args:
msg:
The non-empty string read from the client stream.
"""
try:
args, argv = self._parser.parse_known_args(msg.split(' '))
kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e:
self._writer.write(str(e).encode())
return
except HelpRequested:
return
if argv:
log.debug("%s sent unknown arguments: %s", self._client_class_name, msg)
self._writer.write(b"Invalid command!")
return
await self._execute_command(args)
await self._execute_command(**kwargs)
async def listen(self) -> None:
"""
Enters the main control loop that only ends if either the server or the client disconnect.
Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
This method should be called, when the client connection was established and the handshake was successful.
It will obviously block indefinitely.
"""
while self._control_server.is_serving():
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip()
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self._client_class_name)
break

View File

@@ -0,0 +1,60 @@
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter
from typing import Type, TypeVar
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO
from .exceptions import HelpRequested
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
FORMATTER_CLASS = 'formatter_class'
NUM = 'num'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER)
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
if not name_or_flags:
name_or_flags = (NUM, )
kwargs.setdefault('nargs', '?')
kwargs.setdefault('default', 1)
kwargs.setdefault('type', int)
return self.add_argument(*name_or_flags, **kwargs)

View File

@@ -94,3 +94,35 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
self.assertEqual("tasks", helpers.tasks_str(2))
self.assertEqual("tasks", helpers.tasks_str(-10))
self.assertEqual("tasks", helpers.tasks_str(42))
def test_get_first_doc_line(self):
expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output}
something else
even more
""")
output = helpers.get_first_doc_line(mock_obj)
self.assertEqual(expected_output, output)
async def test_return_or_exception(self):
expected_output = '420'
mock_func = AsyncMock(return_value=expected_output)
args = (1, 3, 5)
kwargs = {'a': 1, 'b': 2, 'c': 'foo'}
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_awaited_once_with(*args, **kwargs)
mock_func = MagicMock(return_value=expected_output)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
class TestException(Exception):
pass
test_exception = TestException()
mock_func = MagicMock(side_effect=test_exception)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(test_exception, output)
mock_func.assert_called_once_with(*args, **kwargs)

View File

@@ -1,3 +1,24 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.server` module.
"""
import asyncio
import logging
from pathlib import Path

324
tests/test_session.py Normal file
View File

@@ -0,0 +1,324 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.session` module.
"""
import json
from argparse import ArgumentError, Namespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = session.log.level
session.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
session.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
self.mock_client_class_name = FOO + BAR
self.mock_server = MagicMock(pool=self.mock_pool,
client_class_name=self.mock_client_class_name)
self.mock_reader = MagicMock()
self.mock_writer = MagicMock()
self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
def test_init(self):
self.assertEqual(self.mock_server, self.session._control_server)
self.assertEqual(self.mock_pool, self.session._pool)
self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser)
self.assertIsNone(self.session._subparsers)
def test__add_command(self):
expected_output = 123456
mock_add_parser = MagicMock(return_value=expected_output)
self.session._subparsers = MagicMock(add_parser=mock_add_parser)
self.session._parser = MagicMock()
name, prog, short_help, long_help = 'abc', None, 'short123', None
kwargs = {'x': 1, 'y': 2}
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
parent=self.session._parser, **kwargs)
mock_add_parser.reset_mock()
prog, long_help = 'ffffff', 'so long, wow'
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
parent=self.session._parser, **kwargs)
mock_add_parser.reset_mock()
short_help = None
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
parent=self.session._parser, **kwargs)
@patch.object(session, 'get_first_doc_line')
@patch.object(session.ControlSession, '_add_command')
def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
self.assertIsNone(self.session._add_base_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
mock__add_command.reset_mock()
mock_get_first_doc_line.reset_mock()
self.assertIsNone(self.session._add_simple_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
with self.assertRaises(NotImplementedError):
self.session._add_advanced_commands()
@patch.object(session.ControlSession, '_add_simple_commands')
@patch.object(session.ControlSession, '_add_advanced_commands')
@patch.object(session.ControlSession, '_add_base_commands')
@patch.object(session, 'CommandParser')
def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
mock_command_parser_cls.return_value = mock_parser = MagicMock()
self.session._pool = TaskPool()
width = 1234
expected_parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
}
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_called_once_with()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
async def fake_coroutine(): pass
self.session._pool = SimpleTaskPool(fake_coroutine)
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_called_once_with()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
class FakeTaskPool(BaseTaskPool):
pass
self.session._pool = FakeTaskPool()
with self.assertRaises(UnknownTaskPoolClass):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
self.session._pool = MagicMock()
with self.assertRaises(NotATaskPool):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
@patch.object(session.ControlSession, '_init_parser')
async def test_client_handshake(self, mock__init_parser: MagicMock):
width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock()
self.assertIsNone(await self.session.client_handshake())
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
mock__init_parser.assert_called_once_with(width)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session, 'return_or_exception')
async def test__write_function_output(self, mock_return_or_exception: MagicMock):
self.mock_writer.write = MagicMock()
mock_return_or_exception.return_value = None
func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(b"ok")
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
mock_return_or_exception.return_value = output = MagicMock()
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(str(output).encode())
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_name())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fset, self.session._pool, num
)
mock__write_function_output.reset_mock()
kwargs.pop(session.NUM)
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_num_running())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.num_running.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_start(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_start(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_stop(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_stop_all())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_func_name())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.func_name.fget, self.session._pool
)
async def test__execute_command(self):
mock_method = AsyncMock()
cmd = 'this-is-a-test'
setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
kwargs = {FOO: BAR, 'hello': 'python'}
self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
mock_method.assert_awaited_once_with(**kwargs)
@patch.object(session.ControlSession, '_execute_command')
async def test__parse_command(self, mock__execute_command: AsyncMock):
msg = 'asdf asd as a'
kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_awaited_once_with(**kwargs)
mock__execute_command.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__execute_command.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_not_awaited()
@patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty():
self.mock_reader.read.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg)
self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock()
self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited()
self.mock_writer.drain.assert_not_awaited()