big rework of the session-parser-interaction;

dynamically adding pool methods/properties as parser commands;
dynamically executing selected pool method/property;
greatly simplified `ControlSession` class;
removed the need for hard-coded command names;
adjusted unittests accordingly
This commit is contained in:
Daniil Fajnberg 2022-03-13 14:56:56 +01:00
parent eb152e4d75
commit c72a5035ea
11 changed files with 702 additions and 755 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.5.1
version = 0.6.0
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -27,32 +27,12 @@ DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
SESSION_WRITER = 'session_writer'
STREAM_WRITER = 'stream_writer'
CMD = 'command'
CMD_OK = b"ok"
class CLIENT_INFO:
__slots__ = ()
TERMINAL_WIDTH = 'terminal_width'
class CMD:
__slots__ = ()
# Base commands:
CMD = 'command'
NAME = 'name'
POOL_SIZE = 'pool-size'
IS_LOCKED = 'is-locked'
LOCK = 'lock'
UNLOCK = 'unlock'
NUM_RUNNING = 'num-running'
NUM_CANCELLATIONS = 'num-cancellations'
NUM_ENDED = 'num-ended'
NUM_FINISHED = 'num-finished'
IS_FULL = 'is-full'
GET_GROUP_IDS = 'get-group-ids'
# Simple commands:
START = 'start'
STOP = 'stop'
STOP_ALL = 'stop-all'
FUNC_NAME = 'func-name'

View File

@ -63,13 +63,5 @@ class ServerException(Exception):
pass
class UnknownTaskPoolClass(ServerException):
pass
class NotATaskPool(ServerException):
pass
class HelpRequested(ServerException):
pass

View File

@ -51,10 +51,6 @@ async def join_queue(q: Queue) -> None:
await q.join()
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0].strip()

View File

@ -0,0 +1,299 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the `ControlParser` class used by a control server.
"""
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature
from shutil import get_terminal_size
from typing import Callable, Container, Dict, Set, Type, TypeVar
from .constants import CLIENT_INFO, CMD, STREAM_WRITER
from .exceptions import HelpRequested
from .helpers import get_first_doc_line
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
ParsersDict = Dict[str, 'ControlParser']
OMIT_PARAMS_DEFAULT = ('self', )
FORMATTER_CLASS = 'formatter_class'
NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description'
class ControlParser(ArgumentParser):
"""
Subclass of the standard `argparse.ArgumentParser` for remote interaction.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client.
Finally, it offers some convenience methods and makes use of custom exceptions.
"""
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
"""
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
as convenient, when making use of sub-parsers.
Args:
terminal_width:
The number of columns of the terminal to which to adjust help formatting.
base_cls (optional):
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.
Returns:
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
"""
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None,
**kwargs) -> None:
"""
Sets additional internal attributes depending on whether a parent-parser was defined.
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
Args:
stream_writer:
The instance of the `asyncio.StreamWriter` to use for message output.
terminal_width (optional):
The terminal width to assume for all message formatting. Defaults to `shutil.get_terminal_size`.
**kwargs(optional):
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
is empty.
"""
self._stream_writer: StreamWriter = stream_writer
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs)
self._flags: Set[str] = set()
self._commands = None
def add_function_command(self, function: Callable, omit_params: Container[str] = OMIT_PARAMS_DEFAULT,
**subparser_kwargs) -> 'ControlParser':
"""
Takes a function along with its parameters and adds a corresponding (sub-)command to the parser.
The `add_subparsers` method must have been called prior to this.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any public method of the `SimpleTaskPool` class.
Args:
function:
The reference to the function to be "converted" to a parser command.
omit_params (optional):
Names of function parameters not to add as parser arguments.
**subparser_kwargs (optional):
Passed directly to the `add_parser` method.
Returns:
The subparser instance created from the function.
"""
subparser_kwargs.setdefault(NAME, function.__name__.replace('_', '-'))
subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME])
subparser_kwargs.setdefault(HELP, get_first_doc_line(function))
subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP])
subparser: ControlParser = self._commands.add_parser(**subparser_kwargs)
subparser.add_function_args(function, omit_params)
return subparser
def add_property_command(self, prop: property, cls_name: str = '', **subparser_kwargs) -> 'ControlParser':
"""
Same as the `add_function_command` method, but for properties.
Args:
prop:
The reference to the property to be "converted" to a parser command.
cls_name (optional):
Name of the class the property is defined on to appear in the command help text.
**subparser_kwargs (optional):
Passed directly to the `add_parser` method.
Returns:
The subparser instance created from the property.
"""
subparser_kwargs.setdefault(NAME, prop.fget.__name__.replace('_', '-'))
subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME])
getter_help = get_first_doc_line(prop.fget)
if prop.fset is None:
subparser_kwargs.setdefault(HELP, getter_help)
else:
subparser_kwargs.setdefault(HELP, f"Get/set the `{cls_name}.{subparser_kwargs[NAME]}` property")
subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP])
subparser: ControlParser = self._commands.add_parser(**subparser_kwargs)
if prop.fset is not None:
_, param = signature(prop.fset).parameters.values()
setter_arg_help = f"If provided: {get_first_doc_line(prop.fset)} If omitted: {getter_help}"
subparser.add_function_arg(param, nargs='?', default=SUPPRESS, help=setter_arg_help)
return subparser
def add_class_commands(self, cls: Type, public_only: bool = True, omit_members: Container[str] = (),
member_arg_name: str = CMD) -> ParsersDict:
"""
Takes a class and adds its methods and properties as (sub-)commands to the parser.
The `add_subparsers` method must have been called prior to this.
NOTE: Currently, only a limited spectrum of function parameters can be accurately converted to parser arguments.
This method works correctly with the `SimpleTaskPool` class.
Args:
cls:
The reference to the class whose methods/properties are to be "converted" to parser commands.
public_only (optional):
If `False`, protected and private members are considered as well. `True` by default.
omit_members (optional):
Names of functions/properties not to add as parser commands.
member_arg_name (optional):
After parsing the arguments, depending on which command was invoked by the user, the corresponding
method/property will be stored as an extra argument in the parsed namespace under this attribute name.
Defaults to `constants.CMD`.
Returns:
Dictionary mapping class member names to the (sub-)parsers created from them.
"""
parsers: ParsersDict = {}
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
for name, member in getmembers(cls):
if name in omit_members or (name.startswith('_') and public_only):
continue
if isfunction(member):
subparser = self.add_function_command(member, **common_kwargs)
elif isinstance(member, property):
subparser = self.add_property_command(member, cls.__name__, **common_kwargs)
else:
continue
subparser.set_defaults(**{member_arg_name: member})
parsers[name] = subparser
return parsers
def add_subparsers(self, *args, **kwargs):
"""Adds the subparsers action as an internal attribute before returning it."""
self._commands = super().add_subparsers(*args, **kwargs)
return self._commands
def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
if message:
self._stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked."""
if message:
self._print_message(message)
def error(self, message: str) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().error(message=message)
raise HelpRequested
def print_help(self, file=None) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().print_help(file)
raise HelpRequested
def add_function_arg(self, parameter: Parameter, **kwargs) -> Action:
"""
Takes an `inspect.Parameter` of a function and adds a corresponding argument to the parser.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any parameter of any public method of the `SimpleTaskPool` class.
Args:
parameter: The `inspect.Parameter` object to be converted to a parser argument.
**kwargs: Passed to the `add_argument` method of the base class.
Returns:
The `argparse.Action` returned by the `add_argument` method.
"""
if parameter.default is Parameter.empty:
# A non-optional function parameter should correspond to a positional argument.
name_or_flags = [parameter.name]
else:
flag = None
long = f'--{parameter.name.replace("_", "-")}'
# We try to generate a short version (flag) for the argument.
letter = parameter.name[0]
if letter not in self._flags:
flag = f'-{letter}'
self._flags.add(letter)
elif letter.upper() not in self._flags:
flag = f'-{letter.upper()}'
self._flags.add(letter.upper())
name_or_flags = [long] if flag is None else [flag, long]
if parameter.annotation is bool:
# If we are dealing with a boolean parameter, always use the 'store_true' action.
# Even if the parameter's default value is `True`, this will make the parser argument's default `False`.
kwargs.setdefault('action', 'store_true')
else:
# For now, any other type annotation will implicitly use the default action 'store'.
# In addition, we always set the default value.
kwargs.setdefault('default', parameter.default)
if parameter.kind == Parameter.VAR_POSITIONAL:
# This is to be able to later unpack an arbitrary number of positional arguments.
kwargs.setdefault('nargs', '*')
if not kwargs.get('action') == 'store_true':
# The lambda wrapper around the type annotation is to avoid ValueError being raised on suppressed arguments.
# See: https://bugs.python.org/issue36078
kwargs.setdefault('type', get_arg_type_wrapper(parameter.annotation))
return self.add_argument(*name_or_flags, **kwargs)
def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None:
"""
Takes a function reference and adds its parameters as arguments to the parser.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any public method of the `SimpleTaskPool` class.
Args:
function:
The function whose parameters are to be converted to parser arguments.
Its parameters must be properly annotated.
omit (optional):
Names of function parameters not to add as parser arguments.
"""
for param in signature(function).parameters.values():
if param.name not in omit:
# TODO: Look into parsing docstrings properly to try and extract argument help text.
# For now, the argument help just shows the type it will be converted to.
self.add_function_arg(param, help=repr(param.annotation))
def get_arg_type_wrapper(cls: Type) -> Callable:
def wrapper(arg):
return arg if arg is SUPPRESS else cls(arg)
return wrapper

View File

@ -15,21 +15,22 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the control session class used by the control server.
This module contains the the definition of the `ControlSession` class used by the control server.
"""
import logging
import json
from argparse import ArgumentError, HelpFormatter
from argparse import ArgumentError
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
from inspect import isfunction, signature
from typing import Callable, Optional, Union, TYPE_CHECKING
from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM
from .constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from .exceptions import HelpRequested
from .helpers import return_or_exception
from .pool import TaskPool, SimpleTaskPool
from .parser import ControlParser
if TYPE_CHECKING:
from .server import ControlServer
@ -67,273 +68,88 @@ class ControlSession:
self._client_class_name = server.client_class_name
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None
self._subparsers = None
self._parser: Optional[ControlParser] = None
def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
**kwargs) -> CommandParser:
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
"""
Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.
Takes a pool method reference, executes it, and writes a response accordingly.
Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.
If the first parameter is named `self`, the method will be called with the `_pool` instance as its first
positional argument. If it returns nothing, the response upon successful execution will be `constants.CMD_OK`,
otherwise the response written to the stream will be its return value (as an encoded string).
Args:
name:
The command name; passed directly into the `add_parser` method.
prog (optional):
Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
equal to the `name` argument.
short_help (optional):
Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
`long_help` argument is present; in that case the `long_help` argument is passed as `help`.
long_help (optional):
Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
prop:
The reference to the method defined on the `_pool` instance's class.
**kwargs (optional):
Any keyword-arguments to directly pass into the `add_parser` method.
Returns:
An instance of the `CommandParser` class representing the newly added control command.
Must correspond to the arguments expected by the `method`.
Correctly unpacks arbitrary-length positional and keyword-arguments.
"""
if prog is None:
prog = name
kwargs.setdefault('help', short_help or long_help)
kwargs.setdefault('description', long_help or short_help)
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
log.warning("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__)
normal_pos, var_pos = [], []
for param in signature(method).parameters.values():
if param.name == 'self':
normal_pos.append(self._pool)
elif param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
normal_pos.append(kwargs.pop(param.name))
elif param.kind == param.VAR_POSITIONAL:
var_pos = kwargs.pop(param.name)
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
self._writer.write(CMD_OK if output is None else str(output).encode())
def _add_base_commands(self) -> None:
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
"""
Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.
Takes a pool property reference, executes its setter or getter, and writes a response accordingly.
These include commands mapping to the following pool methods:
- __str__
- pool_size (get/set property)
- is_locked
- lock & unlock
- num_running
"""
cls: Type[BaseTaskPool] = self._pool.__class__
self._add_command(CMD.NAME, short_help=get_first_doc_line(cls.__str__))
self._add_command(
CMD.POOL_SIZE,
short_help="Get/set the maximum number of tasks in the pool.",
formatter_class=HelpFormatter
).add_optional_num_argument(
default=None,
help=f"If passed a number: {get_first_doc_line(cls.pool_size.fset)} "
f"If omitted: {get_first_doc_line(cls.pool_size.fget)}"
)
self._add_command(CMD.IS_LOCKED, short_help=get_first_doc_line(cls.is_locked.fget))
self._add_command(CMD.LOCK, short_help=get_first_doc_line(cls.lock))
self._add_command(CMD.UNLOCK, short_help=get_first_doc_line(cls.unlock))
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(cls.num_running.fget))
self._add_command(CMD.NUM_CANCELLATIONS, short_help=get_first_doc_line(cls.num_cancellations.fget))
self._add_command(CMD.NUM_ENDED, short_help=get_first_doc_line(cls.num_ended.fget))
self._add_command(CMD.NUM_FINISHED, short_help=get_first_doc_line(cls.num_finished.fget))
self._add_command(CMD.IS_FULL, short_help=get_first_doc_line(cls.is_full.fget))
self._add_command(
CMD.GET_GROUP_IDS, short_help=get_first_doc_line(cls.get_group_ids)
).add_argument(
'group_name',
nargs='*',
help="Must be a name of a task group that exists within the pool."
)
def _add_simple_commands(self) -> None:
"""
Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.
These include commands mapping to the following pool methods:
- start
- stop
- stop_all
- func_name
"""
self._add_command(
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
).add_optional_num_argument(
help="Number of tasks to start."
)
self._add_command(
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
).add_optional_num_argument(
help="Number of tasks to stop."
)
self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))
def _add_advanced_commands(self) -> None:
"""
Adds the commands that are only supported, if a `TaskPool` object is controlled.
These include commands mapping to the following pool methods:
- ...
"""
raise NotImplementedError
def _init_parser(self, client_terminal_width: int) -> None:
"""
Initializes and fully configures the `CommandParser` responsible for handling the input.
Depending on what specific task pool class is controlled by the server, different commands are added.
The property set/get method will always be called with the `_pool` instance as its first positional argument.
Args:
client_terminal_width:
The number of columns of the client's terminal to be able to nicely format messages from the parser.
prop:
The reference to the property defined on the `_pool` instance's class.
**kwargs (optional):
If not empty, the property setter is executed and the keyword arguments are passed along to it; the
response upon successful execution will be `constants.CMD_OK`. Otherwise the property getter is
executed and the response written to the stream will be its return value (as an encoded string).
"""
parser_kwargs = {
'prog': '',
SESSION_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
}
self._parser = CommandParser(**parser_kwargs)
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
self._add_base_commands()
if isinstance(self._pool, TaskPool):
self._add_advanced_commands()
elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands()
elif isinstance(self._pool, BaseTaskPool):
raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}")
if kwargs:
log.warning("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
await return_or_exception(prop.fset, self._pool, **kwargs)
self._writer.write(CMD_OK)
else:
raise NotATaskPool(f"Not a task pool instance: {self._pool}")
log.warning("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
async def client_handshake(self) -> None:
"""
This method must be invoked before starting any other client interaction.
Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
Client info is retrieved, server info is sent back, and the `ControlParser` is initialized and configured.
"""
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
parser_kwargs = {
STREAM_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '',
'usage': f'%(prog)s [-h] [{CMD}] ...'
}
self._parser = ControlParser(**parser_kwargs)
self._parser.add_subparsers(title="Commands",
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
self._parser.add_class_commands(self._pool.__class__)
self._writer.write(str(self._pool).encode())
await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
"""
Acts as a wrapper around a call to a specific task pool method.
The method is called and any exception is caught and saved. If there is no output and no exception caught, a
generic confirmation message is sent back to the client. Otherwise the output or a string representation of
the exception caught is sent back.
Args:
func:
Reference to the task pool method.
*args (optional):
Any positional arguments to call the method with.
*+kwargs (optional):
Any keyword-arguments to call the method with.
"""
output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None:
"""Maps to the `__str__` method of any task pool class."""
log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None:
"""Maps to the `pool_size` property of any task pool class."""
num = kwargs.get(NUM)
if num is None:
log.debug("%s requests pool size", self._client_class_name)
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_is_locked(self, **_kwargs) -> None:
"""Maps to the `is_locked` property of any task pool class."""
log.debug("%s checks locked status", self._client_class_name)
await self._write_function_output(self._pool.__class__.is_locked.fget, self._pool)
async def _cmd_lock(self, **_kwargs) -> None:
"""Maps to the `lock` method of any task pool class."""
log.debug("%s requests locking the pool", self._client_class_name)
await self._write_function_output(self._pool.lock)
async def _cmd_unlock(self, **_kwargs) -> None:
"""Maps to the `unlock` method of any task pool class."""
log.debug("%s requests unlocking the pool", self._client_class_name)
await self._write_function_output(self._pool.unlock)
async def _cmd_num_running(self, **_kwargs) -> None:
"""Maps to the `num_running` property of any task pool class."""
log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_num_cancellations(self, **_kwargs) -> None:
"""Maps to the `num_cancellations` property of any task pool class."""
log.debug("%s requests number of cancelled tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_cancellations.fget, self._pool)
async def _cmd_num_ended(self, **_kwargs) -> None:
"""Maps to the `num_ended` property of any task pool class."""
log.debug("%s requests number of ended tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_ended.fget, self._pool)
async def _cmd_num_finished(self, **_kwargs) -> None:
"""Maps to the `num_finished` property of any task pool class."""
log.debug("%s requests number of finished tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_finished.fget, self._pool)
async def _cmd_is_full(self, **_kwargs) -> None:
"""Maps to the `is_full` property of any task pool class."""
log.debug("%s checks full status", self._client_class_name)
await self._write_function_output(self._pool.__class__.is_full.fget, self._pool)
async def _cmd_get_group_ids(self, **kwargs) -> None:
"""Maps to the `get_group_ids` method of any task pool class."""
log.debug("%s requests task ids for groups %s", self._client_class_name, kwargs['group_name'])
await self._write_function_output(self._pool.get_group_ids, *kwargs['group_name'])
async def _cmd_start(self, **kwargs) -> None:
"""Maps to the `start` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None:
"""Maps to the `stop` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None:
"""Maps to the `stop_all` method of the `SimpleTaskPool` class."""
log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None:
"""Maps to the `func_name` method of the `SimpleTaskPool` class."""
log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
async def _execute_command(self, **kwargs) -> None:
"""
Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.
Args:
**kwargs:
Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
simply passed into the method determined from the command name.
"""
method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
await method(**kwargs)
async def _parse_command(self, msg: str) -> None:
"""
Takes a message from the client and attempts to parse it.
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
`CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
dictionary of keyword-arguments returned by the `CommandParser` passed into it.
`ControlParser`, nothing else happens. Otherwise, the appropriate `_exec...` method is called with the entire
dictionary of keyword-arguments returned by the `ControlParser` passed into it.
Args:
msg:
The non-empty string read from the client stream.
msg: The non-empty string read from the client stream.
"""
try:
kwargs = vars(self._parser.parse_args(msg.split(' ')))
@ -342,7 +158,11 @@ class ControlSession:
return
except HelpRequested:
return
await self._execute_command(**kwargs)
command = kwargs.pop(CMD)
if isfunction(command):
await self._exec_method_and_respond(command, **kwargs)
elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs)
async def listen(self) -> None:
"""

View File

@ -1,127 +0,0 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the `CommandParser` class used in a control server session.
"""
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter
from typing import Type, TypeVar
from .constants import SESSION_WRITER, CLIENT_INFO
from .exceptions import HelpRequested
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
FORMATTER_CLASS = 'formatter_class'
NUM = 'num'
class CommandParser(ArgumentParser):
"""
Subclass of the standard `argparse.ArgumentParser` for remote interaction.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client.
Finally, it offers some convenience methods and makes use of custom exceptions.
"""
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
"""
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
as convenient, when making use of sub-parsers.
Args:
terminal_width:
The number of columns of the terminal to which to adjust help formatting.
base_cls (optional):
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.
Returns:
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
"""
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None:
"""
Sets additional internal attributes depending on whether a parent-parser was defined.
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
Args:
parent (optional):
An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method
of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present,
the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`.
**kwargs(optional):
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
is empty.
"""
self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER)
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs)
@property
def session_writer(self) -> StreamWriter:
"""Returns the predefined stream writer object of the control session."""
return self._session_writer
@property
def terminal_width(self) -> int:
"""Returns the predefined terminal width."""
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
if message:
self._session_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked."""
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().print_help(file)
raise HelpRequested
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
"""Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified."""
if not name_or_flags:
name_or_flags = (NUM, )
kwargs.setdefault('nargs', '?')
kwargs.setdefault('default', 1)
kwargs.setdefault('type', int)
return self.add_argument(*name_or_flags, **kwargs)

View File

@ -87,14 +87,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()
def test_task_str(self):
self.assertEqual("task", helpers.tasks_str(1))
self.assertEqual("tasks", helpers.tasks_str(0))
self.assertEqual("tasks", helpers.tasks_str(-1))
self.assertEqual("tasks", helpers.tasks_str(2))
self.assertEqual("tasks", helpers.tasks_str(-10))
self.assertEqual("tasks", helpers.tasks_str(42))
def test_get_first_doc_line(self):
expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output}

259
tests/test_parser.py Normal file
View File

@ -0,0 +1,259 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.control.parser` module.
"""
from argparse import ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, SUPPRESS
from inspect import signature
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
from asyncio_taskpool import parser
from asyncio_taskpool.exceptions import HelpRequested
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(TestCase):
def setUp(self) -> None:
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.stream_writer, self.terminal_width = MagicMock(), 420
self.kwargs = {
'stream_writer': self.stream_writer,
'terminal_width': self.terminal_width,
parser.FORMATTER_CLASS: FOO
}
self.parser = parser.ControlParser(**self.kwargs)
def tearDown(self) -> None:
self.help_formatter_factory_patcher.stop()
def test_help_formatter_factory(self):
self.help_formatter_factory_patcher.stop()
class MockBaseClass(HelpFormatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
terminal_width = 123456789
cls = parser.ControlParser.help_formatter_factory(terminal_width, MockBaseClass)
self.assertTrue(issubclass(cls, MockBaseClass))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
cls = parser.ControlParser.help_formatter_factory(terminal_width)
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.stream_writer, self.parser._stream_writer)
self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
self.assertSetEqual(set(), self.parser._flags)
self.assertIsNone(self.parser._commands)
@patch.object(parser, 'get_first_doc_line')
def test_add_function_command(self, mock_get_first_doc_line: MagicMock):
def foo_bar(): pass
mock_subparser = MagicMock()
mock_add_parser = MagicMock(return_value=mock_subparser)
self.parser._commands = MagicMock(add_parser=mock_add_parser)
mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'foo-bar'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
to_omit = ['abc', 'xyz']
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
self.assertEqual(mock_subparser, output)
mock_add_parser.assert_called_once_with(**expected_kwargs)
mock_subparser.add_function_args.assert_called_once_with(foo_bar, to_omit)
@patch.object(parser, 'get_first_doc_line')
def test_add_property_command(self, mock_get_first_doc_line: MagicMock):
def get_prop(_self): pass
def set_prop(_self, _value): pass
prop = property(get_prop)
mock_subparser = MagicMock()
mock_add_parser = MagicMock(return_value=mock_subparser)
self.parser._commands = MagicMock(add_parser=mock_add_parser)
mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'get-prop'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_called_once_with(get_prop)
mock_add_parser.assert_called_once_with(**expected_kwargs)
mock_subparser.add_function_arg.assert_not_called()
mock_get_first_doc_line.reset_mock()
mock_add_parser.reset_mock()
prop = property(get_prop, set_prop)
expected_help = f"Get/set the `.{expected_name}` property"
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs
output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
mock_add_parser.assert_called_once_with(**expected_kwargs)
mock_subparser.add_function_arg.assert_called_once_with(
tuple(signature(set_prop).parameters.values())[1],
nargs='?',
default=SUPPRESS,
help=f"If provided: {mock_help} If omitted: {mock_help}"
)
@patch.object(parser.ControlParser, 'add_property_command')
@patch.object(parser.ControlParser, 'add_function_command')
def test_add_class_commands(self, mock_add_function_command: MagicMock, mock_add_property_command: MagicMock):
class FooBar:
some_attribute = None
def _protected(self, _): pass
def __private(self, _): pass
def to_omit(self, _): pass
def method(self, _): pass
@property
def prop(self): return None
mock_set_defaults = MagicMock()
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
x = 'x'
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer,
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
self.assertDictEqual(expected_output, output)
mock_add_function_command.assert_called_once_with(FooBar.method, **common_kwargs)
mock_add_property_command.assert_called_once_with(FooBar.prop, FooBar.__name__, **common_kwargs)
mock_set_defaults.assert_has_calls([call(**{x: FooBar.method}), call(**{x: FooBar.prop})])
def test__print_message(self):
self.stream_writer.write = MagicMock()
self.assertIsNone(self.parser._print_message(''))
self.stream_writer.write.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg))
self.stream_writer.write.assert_called_once_with(msg.encode())
@patch.object(parser.ControlParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock):
self.assertIsNone(self.parser.exit(123, ''))
mock__print_message.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser.exit(123, msg))
mock__print_message.assert_called_once_with(msg)
@patch.object(parser.ArgumentParser, 'error')
def test_error(self, mock_supercls_error: MagicMock):
with self.assertRaises(HelpRequested):
self.parser.error(FOO + BAR)
mock_supercls_error.assert_called_once_with(message=FOO + BAR)
@patch.object(parser.ArgumentParser, 'print_help')
def test_print_help(self, mock_print_help: MagicMock):
arg = MagicMock()
with self.assertRaises(HelpRequested):
self.parser.print_help(arg)
mock_print_help.assert_called_once_with(arg)
@patch.object(parser, 'get_arg_type_wrapper')
@patch.object(parser.ArgumentParser, 'add_argument')
def test_add_function_arg(self, mock_add_argument: MagicMock, mock_get_arg_type_wrapper: MagicMock):
mock_add_argument.return_value = expected_output = 'action'
mock_get_arg_type_wrapper.return_value = mock_type = 'fake'
foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex
bar_default, baz_default, boo_default = 1, 0.1, 1j
def func(foo: foo_type, *args: args_type, bar: bar_type = bar_default, baz: baz_type = baz_default,
boo: boo_type = boo_default, flag: bool = False):
return foo, args, bar, baz, boo, flag
param_foo, param_args, param_bar, param_baz, param_boo, param_flag = signature(func).parameters.values()
kwargs = {FOO + BAR: 'xyz'}
self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs))
mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(foo_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs))
mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(args_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs))
mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(bar_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs))
mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(baz_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs))
mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(boo_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs))
mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs)
mock_get_arg_type_wrapper.assert_not_called()
@patch.object(parser.ControlParser, 'add_function_arg')
def test_add_function_args(self, mock_add_function_arg: MagicMock):
def func(foo: str, *args: int, bar: float = 0.1):
return foo, args, bar
_, param_args, param_bar = signature(func).parameters.values()
self.assertIsNone(self.parser.add_function_args(func, omit=['foo']))
mock_add_function_arg.assert_has_calls([
call(param_args, help=repr(param_args.annotation)),
call(param_bar, help=repr(param_bar.annotation)),
])
class RestTestCase(TestCase):
def test_get_arg_type_wrapper(self):
type_wrap = parser.get_arg_type_wrapper(int)
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
self.assertEqual(13, type_wrap('13'))

View File

@ -25,9 +25,9 @@ from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER
from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool
FOO, BAR = 'foo', 'bar'
@ -61,236 +61,105 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser)
self.assertIsNone(self.session._subparsers)
def test__add_command(self):
expected_output = 123456
mock_add_parser = MagicMock(return_value=expected_output)
self.session._subparsers = MagicMock(add_parser=mock_add_parser)
self.session._parser = MagicMock()
name, prog, short_help, long_help = 'abc', None, 'short123', None
kwargs = {'x': 1, 'y': 2}
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
parent=self.session._parser, **kwargs)
@patch.object(session, 'return_or_exception')
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
def method(self, arg1, arg2, *var_args, **rest): pass
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest
mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs))
mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
)
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
mock_add_parser.reset_mock()
@patch.object(session, 'return_or_exception')
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
def prop_get(_): pass
def prop_set(_): pass
prop = property(prop_get, prop_set)
kwargs = {'value': 'something'}
mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
prog, long_help = 'ffffff', 'so long, wow'
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
parent=self.session._parser, **kwargs)
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
mock_add_parser.reset_mock()
mock_return_or_exception.return_value = val = 420.69
self.assertIsNone(await self.session._exec_property_and_respond(prop))
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
self.mock_writer.write.assert_called_once_with(str(val).encode())
short_help = None
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
parent=self.session._parser, **kwargs)
@patch.object(session, 'get_first_doc_line')
@patch.object(session.ControlSession, '_add_command')
def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
self.assertIsNone(self.session._add_base_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
mock__add_command.reset_mock()
mock_get_first_doc_line.reset_mock()
self.assertIsNone(self.session._add_simple_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
with self.assertRaises(NotImplementedError):
self.session._add_advanced_commands()
@patch.object(session.ControlSession, '_add_simple_commands')
@patch.object(session.ControlSession, '_add_advanced_commands')
@patch.object(session.ControlSession, '_add_base_commands')
@patch.object(session, 'CommandParser')
def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
mock_command_parser_cls.return_value = mock_parser = MagicMock()
self.session._pool = TaskPool()
width = 1234
expected_parser_kwargs = {
'prog': '',
SESSION_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
}
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_called_once_with()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
async def fake_coroutine(): pass
self.session._pool = SimpleTaskPool(fake_coroutine)
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_called_once_with()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
class FakeTaskPool(BaseTaskPool):
pass
self.session._pool = FakeTaskPool()
with self.assertRaises(UnknownTaskPoolClass):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
self.session._pool = MagicMock()
with self.assertRaises(NotATaskPool):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
@patch.object(session.ControlSession, '_init_parser')
async def test_client_handshake(self, mock__init_parser: MagicMock):
@patch.object(session, 'ControlParser')
async def test_client_handshake(self, mock_parser_cls: MagicMock):
mock_add_subparsers, mock_add_class_commands = MagicMock(), MagicMock()
mock_parser = MagicMock(add_subparsers=mock_add_subparsers, add_class_commands=mock_add_class_commands)
mock_parser_cls.return_value = mock_parser
width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = {
STREAM_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '',
'usage': f'%(prog)s [-h] [{CMD}] ...'
}
expected_subparsers_kwargs = {
'title': "Commands",
'metavar': "(A command followed by '-h' or '--help' will show command-specific help.)"
}
self.assertIsNone(await self.session.client_handshake())
self.assertEqual(mock_parser, self.session._parser)
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
mock__init_parser.assert_called_once_with(width)
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session, 'return_or_exception')
async def test__write_function_output(self, mock_return_or_exception: MagicMock):
self.mock_writer.write = MagicMock()
mock_return_or_exception.return_value = None
func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(b"ok")
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
mock_return_or_exception.return_value = output = MagicMock()
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(str(output).encode())
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_name())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fset, self.session._pool, num
)
mock__write_function_output.reset_mock()
kwargs.pop(session.NUM)
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_num_running())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.num_running.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_start(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_start(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_stop(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_stop_all())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_func_name())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.func_name.fget, self.session._pool
)
async def test__execute_command(self):
mock_method = AsyncMock()
cmd = 'this-is-a-test'
setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
kwargs = {FOO: BAR, 'hello': 'python'}
self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
mock_method.assert_awaited_once_with(**kwargs)
@patch.object(session.ControlSession, '_execute_command')
async def test__parse_command(self, mock__execute_command: AsyncMock):
@patch.object(session.ControlSession, '_exec_property_and_respond')
@patch.object(session.ControlSession, '_exec_method_and_respond')
async def test__parse_command(self, mock__exec_method_and_respond: AsyncMock,
mock__exec_property_and_respond: AsyncMock):
def method(_): pass
prop = property(method)
msg = 'asdf asd as a'
kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_awaited_once_with(**kwargs)
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
mock__exec_property_and_respond.assert_not_called()
mock__execute_command.reset_mock()
mock__exec_method_and_respond.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: prop}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__execute_command.assert_not_awaited()
mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock()
@ -299,7 +168,8 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_not_awaited()
mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited()
@patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock):

View File

@ -1,134 +0,0 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.session_parser` module.
"""
from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch
from asyncio_taskpool import session_parser
from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO
from asyncio_taskpool.exceptions import HelpRequested
FOO = 'foo'
class ControlServerTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.session_writer, self.terminal_width = MagicMock(), 420
self.kwargs = {
SESSION_WRITER: self.session_writer,
CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width,
session_parser.FORMATTER_CLASS: FOO
}
self.parser = session_parser.CommandParser(**self.kwargs)
def tearDown(self) -> None:
self.help_formatter_factory_patcher.stop()
def test_help_formatter_factory(self):
self.help_formatter_factory_patcher.stop()
class MockBaseClass(HelpFormatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
terminal_width = 123456789
cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass)
self.assertTrue(issubclass(cls, MockBaseClass))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
cls = session_parser.CommandParser.help_formatter_factory(terminal_width)
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.session_writer, self.parser._session_writer)
self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
def test_session_writer(self):
self.assertEqual(self.session_writer, self.parser.session_writer)
def test_terminal_width(self):
self.assertEqual(self.terminal_width, self.parser.terminal_width)
def test__print_message(self):
self.session_writer.write = MagicMock()
self.assertIsNone(self.parser._print_message(''))
self.session_writer.write.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg))
self.session_writer.write.assert_called_once_with(msg.encode())
@patch.object(session_parser.CommandParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock):
self.assertIsNone(self.parser.exit(123, ''))
mock__print_message.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser.exit(123, msg))
mock__print_message.assert_called_once_with(msg)
@patch.object(session_parser.ArgumentParser, 'print_help')
def test_print_help(self, mock_print_help: MagicMock):
arg = MagicMock()
with self.assertRaises(HelpRequested):
self.parser.print_help(arg)
mock_print_help.assert_called_once_with(arg)
def test_add_optional_num_argument(self):
metavar = 'FOOBAR'
action = self.parser.add_optional_num_argument(metavar=metavar)
self.assertIsInstance(action, Action)
self.assertEqual('?', action.nargs)
self.assertEqual(1, action.default)
self.assertEqual(int, action.type)
self.assertEqual(metavar, action.metavar)
num = 111
kwargs = vars(self.parser.parse_args([f'{num}']))
self.assertDictEqual({session_parser.NUM: num}, kwargs)
name = f'--{FOO}'
nargs = '+'
default = 1
_type = float
required = True
dest = 'foo_bar'
action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type,
required=required, metavar=metavar, dest=dest)
self.assertIsInstance(action, Action)
self.assertEqual(nargs, action.nargs)
self.assertEqual(default, action.default)
self.assertEqual(_type, action.type)
self.assertEqual(required, action.required)
self.assertEqual(metavar, action.metavar)
self.assertEqual(dest, action.dest)
kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5']))
self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)