From c72a5035eab47ddf59a2bb90ec9658b4b690ede8 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 13 Mar 2022 14:56:56 +0100 Subject: [PATCH] big rework of the session-parser-interaction; dynamically adding pool methods/properties as parser commands; dynamically executing selected pool method/property; greatly simplified `ControlSession` class; removed the need for hard-coded command names; adjusted unittests accordingly --- setup.cfg | 2 +- src/asyncio_taskpool/constants.py | 28 +-- src/asyncio_taskpool/exceptions.py | 8 - src/asyncio_taskpool/helpers.py | 4 - src/asyncio_taskpool/parser.py | 299 ++++++++++++++++++++++++ src/asyncio_taskpool/session.py | 308 +++++-------------------- src/asyncio_taskpool/session_parser.py | 127 ---------- tests/test_helpers.py | 8 - tests/test_parser.py | 259 +++++++++++++++++++++ tests/test_session.py | 280 ++++++---------------- tests/test_session_parser.py | 134 ----------- 11 files changed, 702 insertions(+), 755 deletions(-) create mode 100644 src/asyncio_taskpool/parser.py delete mode 100644 src/asyncio_taskpool/session_parser.py create mode 100644 tests/test_parser.py delete mode 100644 tests/test_session_parser.py diff --git a/setup.cfg b/setup.cfg index 8474f40..664e256 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.5.1 +version = 0.6.0 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index 5266469..c4a5523 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -27,32 +27,12 @@ DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S' CLIENT_EXIT = 'exit' SESSION_MSG_BYTES = 1024 * 100 -SESSION_WRITER = 'session_writer' + +STREAM_WRITER = 'stream_writer' +CMD = 'command' +CMD_OK = b"ok" class CLIENT_INFO: __slots__ = () TERMINAL_WIDTH = 'terminal_width' - - -class CMD: - __slots__ = () - # Base commands: - CMD = 'command' - NAME = 'name' - POOL_SIZE = 'pool-size' - IS_LOCKED = 'is-locked' - LOCK = 'lock' - UNLOCK = 'unlock' - NUM_RUNNING = 'num-running' - NUM_CANCELLATIONS = 'num-cancellations' - NUM_ENDED = 'num-ended' - NUM_FINISHED = 'num-finished' - IS_FULL = 'is-full' - GET_GROUP_IDS = 'get-group-ids' - - # Simple commands: - START = 'start' - STOP = 'stop' - STOP_ALL = 'stop-all' - FUNC_NAME = 'func-name' diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py index 6f0c7d5..e2715a7 100644 --- a/src/asyncio_taskpool/exceptions.py +++ b/src/asyncio_taskpool/exceptions.py @@ -63,13 +63,5 @@ class ServerException(Exception): pass -class UnknownTaskPoolClass(ServerException): - pass - - -class NotATaskPool(ServerException): - pass - - class HelpRequested(ServerException): pass diff --git a/src/asyncio_taskpool/helpers.py b/src/asyncio_taskpool/helpers.py index 912c975..8f44d37 100644 --- a/src/asyncio_taskpool/helpers.py +++ b/src/asyncio_taskpool/helpers.py @@ -51,10 +51,6 @@ async def join_queue(q: Queue) -> None: await q.join() -def tasks_str(num: int) -> str: - return "tasks" if num != 1 else "task" - - def get_first_doc_line(obj: object) -> str: return getdoc(obj).strip().split("\n", 1)[0].strip() diff --git a/src/asyncio_taskpool/parser.py b/src/asyncio_taskpool/parser.py new file mode 100644 index 0000000..d13753a --- /dev/null +++ b/src/asyncio_taskpool/parser.py @@ -0,0 +1,299 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +This module contains the the definition of the `ControlParser` class used by a control server. +""" + + +from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS +from asyncio.streams import StreamWriter +from inspect import Parameter, getmembers, isfunction, signature +from shutil import get_terminal_size +from typing import Callable, Container, Dict, Set, Type, TypeVar + +from .constants import CLIENT_INFO, CMD, STREAM_WRITER +from .exceptions import HelpRequested +from .helpers import get_first_doc_line + + +FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter]) +ParsersDict = Dict[str, 'ControlParser'] + +OMIT_PARAMS_DEFAULT = ('self', ) + +FORMATTER_CLASS = 'formatter_class' +NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description' + + +class ControlParser(ArgumentParser): + """ + Subclass of the standard `argparse.ArgumentParser` for remote interaction. + + Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` + instance passed to it during initialization. + Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a + connected client. + Finally, it offers some convenience methods and makes use of custom exceptions. + """ + + @staticmethod + def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: + """ + Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument. + + Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not + as convenient, when making use of sub-parsers. + + Args: + terminal_width: + The number of columns of the terminal to which to adjust help formatting. + base_cls (optional): + The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used. + + Returns: + The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`. + """ + if base_cls is None: + base_cls = ArgumentDefaultsHelpFormatter + + class ClientHelpFormatter(base_cls): + def __init__(self, *args, **kwargs) -> None: + kwargs['width'] = terminal_width + super().__init__(*args, **kwargs) + return ClientHelpFormatter + + def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, + **kwargs) -> None: + """ + Sets additional internal attributes depending on whether a parent-parser was defined. + + The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword. + By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). + + Args: + stream_writer: + The instance of the `asyncio.StreamWriter` to use for message output. + terminal_width (optional): + The terminal width to assume for all message formatting. Defaults to `shutil.get_terminal_size`. + **kwargs(optional): + In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of + the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument + is empty. + """ + self._stream_writer: StreamWriter = stream_writer + self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns + kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) + kwargs.setdefault('exit_on_error', False) + super().__init__(**kwargs) + self._flags: Set[str] = set() + self._commands = None + + def add_function_command(self, function: Callable, omit_params: Container[str] = OMIT_PARAMS_DEFAULT, + **subparser_kwargs) -> 'ControlParser': + """ + Takes a function along with its parameters and adds a corresponding (sub-)command to the parser. + + The `add_subparsers` method must have been called prior to this. + + NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. + This method works correctly with any public method of the `SimpleTaskPool` class. + + Args: + function: + The reference to the function to be "converted" to a parser command. + omit_params (optional): + Names of function parameters not to add as parser arguments. + **subparser_kwargs (optional): + Passed directly to the `add_parser` method. + + Returns: + The subparser instance created from the function. + """ + subparser_kwargs.setdefault(NAME, function.__name__.replace('_', '-')) + subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME]) + subparser_kwargs.setdefault(HELP, get_first_doc_line(function)) + subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP]) + subparser: ControlParser = self._commands.add_parser(**subparser_kwargs) + subparser.add_function_args(function, omit_params) + return subparser + + def add_property_command(self, prop: property, cls_name: str = '', **subparser_kwargs) -> 'ControlParser': + """ + Same as the `add_function_command` method, but for properties. + + Args: + prop: + The reference to the property to be "converted" to a parser command. + cls_name (optional): + Name of the class the property is defined on to appear in the command help text. + **subparser_kwargs (optional): + Passed directly to the `add_parser` method. + + Returns: + The subparser instance created from the property. + """ + subparser_kwargs.setdefault(NAME, prop.fget.__name__.replace('_', '-')) + subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME]) + getter_help = get_first_doc_line(prop.fget) + if prop.fset is None: + subparser_kwargs.setdefault(HELP, getter_help) + else: + subparser_kwargs.setdefault(HELP, f"Get/set the `{cls_name}.{subparser_kwargs[NAME]}` property") + subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP]) + subparser: ControlParser = self._commands.add_parser(**subparser_kwargs) + if prop.fset is not None: + _, param = signature(prop.fset).parameters.values() + setter_arg_help = f"If provided: {get_first_doc_line(prop.fset)} If omitted: {getter_help}" + subparser.add_function_arg(param, nargs='?', default=SUPPRESS, help=setter_arg_help) + return subparser + + def add_class_commands(self, cls: Type, public_only: bool = True, omit_members: Container[str] = (), + member_arg_name: str = CMD) -> ParsersDict: + """ + Takes a class and adds its methods and properties as (sub-)commands to the parser. + + The `add_subparsers` method must have been called prior to this. + + NOTE: Currently, only a limited spectrum of function parameters can be accurately converted to parser arguments. + This method works correctly with the `SimpleTaskPool` class. + + Args: + cls: + The reference to the class whose methods/properties are to be "converted" to parser commands. + public_only (optional): + If `False`, protected and private members are considered as well. `True` by default. + omit_members (optional): + Names of functions/properties not to add as parser commands. + member_arg_name (optional): + After parsing the arguments, depending on which command was invoked by the user, the corresponding + method/property will be stored as an extra argument in the parsed namespace under this attribute name. + Defaults to `constants.CMD`. + + Returns: + Dictionary mapping class member names to the (sub-)parsers created from them. + """ + parsers: ParsersDict = {} + common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width} + for name, member in getmembers(cls): + if name in omit_members or (name.startswith('_') and public_only): + continue + if isfunction(member): + subparser = self.add_function_command(member, **common_kwargs) + elif isinstance(member, property): + subparser = self.add_property_command(member, cls.__name__, **common_kwargs) + else: + continue + subparser.set_defaults(**{member_arg_name: member}) + parsers[name] = subparser + return parsers + + def add_subparsers(self, *args, **kwargs): + """Adds the subparsers action as an internal attribute before returning it.""" + self._commands = super().add_subparsers(*args, **kwargs) + return self._commands + + def _print_message(self, message: str, *args, **kwargs) -> None: + """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" + if message: + self._stream_writer.write(message.encode()) + + def exit(self, status: int = 0, message: str = None) -> None: + """This is overridden to prevent system exit to be invoked.""" + if message: + self._print_message(message) + + def error(self, message: str) -> None: + """This just adds the custom `HelpRequested` exception after the parent class' method.""" + super().error(message=message) + raise HelpRequested + + def print_help(self, file=None) -> None: + """This just adds the custom `HelpRequested` exception after the parent class' method.""" + super().print_help(file) + raise HelpRequested + + def add_function_arg(self, parameter: Parameter, **kwargs) -> Action: + """ + Takes an `inspect.Parameter` of a function and adds a corresponding argument to the parser. + + NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. + This method works correctly with any parameter of any public method of the `SimpleTaskPool` class. + + Args: + parameter: The `inspect.Parameter` object to be converted to a parser argument. + **kwargs: Passed to the `add_argument` method of the base class. + + Returns: + The `argparse.Action` returned by the `add_argument` method. + """ + if parameter.default is Parameter.empty: + # A non-optional function parameter should correspond to a positional argument. + name_or_flags = [parameter.name] + else: + flag = None + long = f'--{parameter.name.replace("_", "-")}' + # We try to generate a short version (flag) for the argument. + letter = parameter.name[0] + if letter not in self._flags: + flag = f'-{letter}' + self._flags.add(letter) + elif letter.upper() not in self._flags: + flag = f'-{letter.upper()}' + self._flags.add(letter.upper()) + name_or_flags = [long] if flag is None else [flag, long] + if parameter.annotation is bool: + # If we are dealing with a boolean parameter, always use the 'store_true' action. + # Even if the parameter's default value is `True`, this will make the parser argument's default `False`. + kwargs.setdefault('action', 'store_true') + else: + # For now, any other type annotation will implicitly use the default action 'store'. + # In addition, we always set the default value. + kwargs.setdefault('default', parameter.default) + if parameter.kind == Parameter.VAR_POSITIONAL: + # This is to be able to later unpack an arbitrary number of positional arguments. + kwargs.setdefault('nargs', '*') + if not kwargs.get('action') == 'store_true': + # The lambda wrapper around the type annotation is to avoid ValueError being raised on suppressed arguments. + # See: https://bugs.python.org/issue36078 + kwargs.setdefault('type', get_arg_type_wrapper(parameter.annotation)) + return self.add_argument(*name_or_flags, **kwargs) + + def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None: + """ + Takes a function reference and adds its parameters as arguments to the parser. + + NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. + This method works correctly with any public method of the `SimpleTaskPool` class. + + Args: + function: + The function whose parameters are to be converted to parser arguments. + Its parameters must be properly annotated. + omit (optional): + Names of function parameters not to add as parser arguments. + """ + for param in signature(function).parameters.values(): + if param.name not in omit: + # TODO: Look into parsing docstrings properly to try and extract argument help text. + # For now, the argument help just shows the type it will be converted to. + self.add_function_arg(param, help=repr(param.annotation)) + + +def get_arg_type_wrapper(cls: Type) -> Callable: + def wrapper(arg): + return arg if arg is SUPPRESS else cls(arg) + return wrapper diff --git a/src/asyncio_taskpool/session.py b/src/asyncio_taskpool/session.py index 97f7982..0e733ac 100644 --- a/src/asyncio_taskpool/session.py +++ b/src/asyncio_taskpool/session.py @@ -15,21 +15,22 @@ You should have received a copy of the GNU Lesser General Public License along w If not, see .""" __doc__ = """ -This module contains the the definition of the control session class used by the control server. +This module contains the the definition of the `ControlSession` class used by the control server. """ import logging import json -from argparse import ArgumentError, HelpFormatter +from argparse import ArgumentError from asyncio.streams import StreamReader, StreamWriter -from typing import Callable, Optional, Type, Union, TYPE_CHECKING +from inspect import isfunction, signature +from typing import Callable, Optional, Union, TYPE_CHECKING -from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO -from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass -from .helpers import get_first_doc_line, return_or_exception, tasks_str -from .pool import BaseTaskPool, TaskPool, SimpleTaskPool -from .session_parser import CommandParser, NUM +from .constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER +from .exceptions import HelpRequested +from .helpers import return_or_exception +from .pool import TaskPool, SimpleTaskPool +from .parser import ControlParser if TYPE_CHECKING: from .server import ControlServer @@ -67,273 +68,88 @@ class ControlSession: self._client_class_name = server.client_class_name self._reader: StreamReader = reader self._writer: StreamWriter = writer - self._parser: Optional[CommandParser] = None - self._subparsers = None + self._parser: Optional[ControlParser] = None - def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None, - **kwargs) -> CommandParser: + async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None: """ - Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance. + Takes a pool method reference, executes it, and writes a response accordingly. - Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument. + If the first parameter is named `self`, the method will be called with the `_pool` instance as its first + positional argument. If it returns nothing, the response upon successful execution will be `constants.CMD_OK`, + otherwise the response written to the stream will be its return value (as an encoded string). Args: - name: - The command name; passed directly into the `add_parser` method. - prog (optional): - Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set - equal to the `name` argument. - short_help (optional): - Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the - `long_help` argument is present; in that case the `long_help` argument is passed as `help`. - long_help (optional): - Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and - the `short_help` argument is present; in that case the `short_help` argument is passed as `description`. + prop: + The reference to the method defined on the `_pool` instance's class. **kwargs (optional): - Any keyword-arguments to directly pass into the `add_parser` method. - - Returns: - An instance of the `CommandParser` class representing the newly added control command. + Must correspond to the arguments expected by the `method`. + Correctly unpacks arbitrary-length positional and keyword-arguments. """ - if prog is None: - prog = name - kwargs.setdefault('help', short_help or long_help) - kwargs.setdefault('description', long_help or short_help) - return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs) + log.warning("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__) + normal_pos, var_pos = [], [] + for param in signature(method).parameters.values(): + if param.name == 'self': + normal_pos.append(self._pool) + elif param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): + normal_pos.append(kwargs.pop(param.name)) + elif param.kind == param.VAR_POSITIONAL: + var_pos = kwargs.pop(param.name) + output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs) + self._writer.write(CMD_OK if output is None else str(output).encode()) - def _add_base_commands(self) -> None: + async def _exec_property_and_respond(self, prop: property, **kwargs) -> None: """ - Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled. + Takes a pool property reference, executes its setter or getter, and writes a response accordingly. - These include commands mapping to the following pool methods: - - __str__ - - pool_size (get/set property) - - is_locked - - lock & unlock - - num_running - """ - cls: Type[BaseTaskPool] = self._pool.__class__ - self._add_command(CMD.NAME, short_help=get_first_doc_line(cls.__str__)) - self._add_command( - CMD.POOL_SIZE, - short_help="Get/set the maximum number of tasks in the pool.", - formatter_class=HelpFormatter - ).add_optional_num_argument( - default=None, - help=f"If passed a number: {get_first_doc_line(cls.pool_size.fset)} " - f"If omitted: {get_first_doc_line(cls.pool_size.fget)}" - ) - self._add_command(CMD.IS_LOCKED, short_help=get_first_doc_line(cls.is_locked.fget)) - self._add_command(CMD.LOCK, short_help=get_first_doc_line(cls.lock)) - self._add_command(CMD.UNLOCK, short_help=get_first_doc_line(cls.unlock)) - self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(cls.num_running.fget)) - self._add_command(CMD.NUM_CANCELLATIONS, short_help=get_first_doc_line(cls.num_cancellations.fget)) - self._add_command(CMD.NUM_ENDED, short_help=get_first_doc_line(cls.num_ended.fget)) - self._add_command(CMD.NUM_FINISHED, short_help=get_first_doc_line(cls.num_finished.fget)) - self._add_command(CMD.IS_FULL, short_help=get_first_doc_line(cls.is_full.fget)) - self._add_command( - CMD.GET_GROUP_IDS, short_help=get_first_doc_line(cls.get_group_ids) - ).add_argument( - 'group_name', - nargs='*', - help="Must be a name of a task group that exists within the pool." - ) - - def _add_simple_commands(self) -> None: - """ - Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled. - - These include commands mapping to the following pool methods: - - start - - stop - - stop_all - - func_name - """ - self._add_command( - CMD.START, short_help=get_first_doc_line(self._pool.__class__.start) - ).add_optional_num_argument( - help="Number of tasks to start." - ) - self._add_command( - CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop) - ).add_optional_num_argument( - help="Number of tasks to stop." - ) - self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all)) - self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)) - - def _add_advanced_commands(self) -> None: - """ - Adds the commands that are only supported, if a `TaskPool` object is controlled. - - These include commands mapping to the following pool methods: - - ... - """ - raise NotImplementedError - - def _init_parser(self, client_terminal_width: int) -> None: - """ - Initializes and fully configures the `CommandParser` responsible for handling the input. - - Depending on what specific task pool class is controlled by the server, different commands are added. + The property set/get method will always be called with the `_pool` instance as its first positional argument. Args: - client_terminal_width: - The number of columns of the client's terminal to be able to nicely format messages from the parser. + prop: + The reference to the property defined on the `_pool` instance's class. + **kwargs (optional): + If not empty, the property setter is executed and the keyword arguments are passed along to it; the + response upon successful execution will be `constants.CMD_OK`. Otherwise the property getter is + executed and the response written to the stream will be its return value (as an encoded string). """ - parser_kwargs = { - 'prog': '', - SESSION_WRITER: self._writer, - CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, - } - self._parser = CommandParser(**parser_kwargs) - self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD) - self._add_base_commands() - if isinstance(self._pool, TaskPool): - self._add_advanced_commands() - elif isinstance(self._pool, SimpleTaskPool): - self._add_simple_commands() - elif isinstance(self._pool, BaseTaskPool): - raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}") + if kwargs: + log.warning("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__) + await return_or_exception(prop.fset, self._pool, **kwargs) + self._writer.write(CMD_OK) else: - raise NotATaskPool(f"Not a task pool instance: {self._pool}") + log.warning("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__) + self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode()) async def client_handshake(self) -> None: """ This method must be invoked before starting any other client interaction. - Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured. + Client info is retrieved, server info is sent back, and the `ControlParser` is initialized and configured. """ client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) log.debug("%s connected", self._client_class_name) - self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH]) + parser_kwargs = { + STREAM_WRITER: self._writer, + CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH], + 'prog': '', + 'usage': f'%(prog)s [-h] [{CMD}] ...' + } + self._parser = ControlParser(**parser_kwargs) + self._parser.add_subparsers(title="Commands", + metavar="(A command followed by '-h' or '--help' will show command-specific help.)") + self._parser.add_class_commands(self._pool.__class__) self._writer.write(str(self._pool).encode()) await self._writer.drain() - async def _write_function_output(self, func: Callable, *args, **kwargs) -> None: - """ - Acts as a wrapper around a call to a specific task pool method. - - The method is called and any exception is caught and saved. If there is no output and no exception caught, a - generic confirmation message is sent back to the client. Otherwise the output or a string representation of - the exception caught is sent back. - - Args: - func: - Reference to the task pool method. - *args (optional): - Any positional arguments to call the method with. - *+kwargs (optional): - Any keyword-arguments to call the method with. - """ - output = await return_or_exception(func, *args, **kwargs) - self._writer.write(b"ok" if output is None else str(output).encode()) - - async def _cmd_name(self, **_kwargs) -> None: - """Maps to the `__str__` method of any task pool class.""" - log.debug("%s requests task pool name", self._client_class_name) - await self._write_function_output(self._pool.__class__.__str__, self._pool) - - async def _cmd_pool_size(self, **kwargs) -> None: - """Maps to the `pool_size` property of any task pool class.""" - num = kwargs.get(NUM) - if num is None: - log.debug("%s requests pool size", self._client_class_name) - await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool) - else: - log.debug("%s requests setting pool size to %s", self._client_class_name, num) - await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num) - - async def _cmd_is_locked(self, **_kwargs) -> None: - """Maps to the `is_locked` property of any task pool class.""" - log.debug("%s checks locked status", self._client_class_name) - await self._write_function_output(self._pool.__class__.is_locked.fget, self._pool) - - async def _cmd_lock(self, **_kwargs) -> None: - """Maps to the `lock` method of any task pool class.""" - log.debug("%s requests locking the pool", self._client_class_name) - await self._write_function_output(self._pool.lock) - - async def _cmd_unlock(self, **_kwargs) -> None: - """Maps to the `unlock` method of any task pool class.""" - log.debug("%s requests unlocking the pool", self._client_class_name) - await self._write_function_output(self._pool.unlock) - - async def _cmd_num_running(self, **_kwargs) -> None: - """Maps to the `num_running` property of any task pool class.""" - log.debug("%s requests number of running tasks", self._client_class_name) - await self._write_function_output(self._pool.__class__.num_running.fget, self._pool) - - async def _cmd_num_cancellations(self, **_kwargs) -> None: - """Maps to the `num_cancellations` property of any task pool class.""" - log.debug("%s requests number of cancelled tasks", self._client_class_name) - await self._write_function_output(self._pool.__class__.num_cancellations.fget, self._pool) - - async def _cmd_num_ended(self, **_kwargs) -> None: - """Maps to the `num_ended` property of any task pool class.""" - log.debug("%s requests number of ended tasks", self._client_class_name) - await self._write_function_output(self._pool.__class__.num_ended.fget, self._pool) - - async def _cmd_num_finished(self, **_kwargs) -> None: - """Maps to the `num_finished` property of any task pool class.""" - log.debug("%s requests number of finished tasks", self._client_class_name) - await self._write_function_output(self._pool.__class__.num_finished.fget, self._pool) - - async def _cmd_is_full(self, **_kwargs) -> None: - """Maps to the `is_full` property of any task pool class.""" - log.debug("%s checks full status", self._client_class_name) - await self._write_function_output(self._pool.__class__.is_full.fget, self._pool) - - async def _cmd_get_group_ids(self, **kwargs) -> None: - """Maps to the `get_group_ids` method of any task pool class.""" - log.debug("%s requests task ids for groups %s", self._client_class_name, kwargs['group_name']) - await self._write_function_output(self._pool.get_group_ids, *kwargs['group_name']) - - async def _cmd_start(self, **kwargs) -> None: - """Maps to the `start` method of the `SimpleTaskPool` class.""" - num = kwargs[NUM] - log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num)) - await self._write_function_output(self._pool.start, num) - - async def _cmd_stop(self, **kwargs) -> None: - """Maps to the `stop` method of the `SimpleTaskPool` class.""" - num = kwargs[NUM] - log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num)) - await self._write_function_output(self._pool.stop, num) - - async def _cmd_stop_all(self, **_kwargs) -> None: - """Maps to the `stop_all` method of the `SimpleTaskPool` class.""" - log.debug("%s requests stopping all tasks", self._client_class_name) - await self._write_function_output(self._pool.stop_all) - - async def _cmd_func_name(self, **_kwargs) -> None: - """Maps to the `func_name` method of the `SimpleTaskPool` class.""" - log.debug("%s requests pool function name", self._client_class_name) - await self._write_function_output(self._pool.__class__.func_name.fget, self._pool) - - async def _execute_command(self, **kwargs) -> None: - """ - Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it. - - Args: - **kwargs: - Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is - simply passed into the method determined from the command name. - """ - method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}') - await method(**kwargs) - async def _parse_command(self, msg: str) -> None: """ Takes a message from the client and attempts to parse it. If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the - `CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire - dictionary of keyword-arguments returned by the `CommandParser` passed into it. + `ControlParser`, nothing else happens. Otherwise, the appropriate `_exec...` method is called with the entire + dictionary of keyword-arguments returned by the `ControlParser` passed into it. Args: - msg: - The non-empty string read from the client stream. + msg: The non-empty string read from the client stream. """ try: kwargs = vars(self._parser.parse_args(msg.split(' '))) @@ -342,7 +158,11 @@ class ControlSession: return except HelpRequested: return - await self._execute_command(**kwargs) + command = kwargs.pop(CMD) + if isfunction(command): + await self._exec_method_and_respond(command, **kwargs) + elif isinstance(command, property): + await self._exec_property_and_respond(command, **kwargs) async def listen(self) -> None: """ diff --git a/src/asyncio_taskpool/session_parser.py b/src/asyncio_taskpool/session_parser.py deleted file mode 100644 index d27e0cc..0000000 --- a/src/asyncio_taskpool/session_parser.py +++ /dev/null @@ -1,127 +0,0 @@ -__author__ = "Daniil Fajnberg" -__copyright__ = "Copyright © 2022 Daniil Fajnberg" -__license__ = """GNU LGPLv3.0 - -This file is part of asyncio-taskpool. - -asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of -version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. - -asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. -See the GNU Lesser General Public License for more details. - -You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. -If not, see .""" - -__doc__ = """ -This module contains the the definition of the `CommandParser` class used in a control server session. -""" - - -from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter -from asyncio.streams import StreamWriter -from typing import Type, TypeVar - -from .constants import SESSION_WRITER, CLIENT_INFO -from .exceptions import HelpRequested - - -FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter]) -FORMATTER_CLASS = 'formatter_class' -NUM = 'num' - - -class CommandParser(ArgumentParser): - """ - Subclass of the standard `argparse.ArgumentParser` for remote interaction. - - Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` - instance passed to it during initialization. - Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a - connected client. - Finally, it offers some convenience methods and makes use of custom exceptions. - """ - - @staticmethod - def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: - """ - Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument. - - Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not - as convenient, when making use of sub-parsers. - - Args: - terminal_width: - The number of columns of the terminal to which to adjust help formatting. - base_cls (optional): - The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used. - - Returns: - The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`. - """ - if base_cls is None: - base_cls = ArgumentDefaultsHelpFormatter - - class ClientHelpFormatter(base_cls): - def __init__(self, *args, **kwargs) -> None: - kwargs['width'] = terminal_width - super().__init__(*args, **kwargs) - return ClientHelpFormatter - - def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None: - """ - Sets additional internal attributes depending on whether a parent-parser was defined. - - The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword. - By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). - - Args: - parent (optional): - An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method - of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present, - the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`. - **kwargs(optional): - In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of - the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument - is empty. - """ - self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER) - self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH) - kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) - kwargs.setdefault('exit_on_error', False) - super().__init__(**kwargs) - - @property - def session_writer(self) -> StreamWriter: - """Returns the predefined stream writer object of the control session.""" - return self._session_writer - - @property - def terminal_width(self) -> int: - """Returns the predefined terminal width.""" - return self._terminal_width - - def _print_message(self, message: str, *args, **kwargs) -> None: - """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" - if message: - self._session_writer.write(message.encode()) - - def exit(self, status: int = 0, message: str = None) -> None: - """This is overridden to prevent system exit to be invoked.""" - if message: - self._print_message(message) - - def print_help(self, file=None) -> None: - """This just adds the custom `HelpRequested` exception after the parent class' method.""" - super().print_help(file) - raise HelpRequested - - def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action: - """Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified.""" - if not name_or_flags: - name_or_flags = (NUM, ) - kwargs.setdefault('nargs', '?') - kwargs.setdefault('default', 1) - kwargs.setdefault('type', int) - return self.add_argument(*name_or_flags, **kwargs) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index ee41d1e..76f3f90 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -87,14 +87,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase): self.assertIsNone(await helpers.join_queue(mock_queue)) mock_join.assert_awaited_once_with() - def test_task_str(self): - self.assertEqual("task", helpers.tasks_str(1)) - self.assertEqual("tasks", helpers.tasks_str(0)) - self.assertEqual("tasks", helpers.tasks_str(-1)) - self.assertEqual("tasks", helpers.tasks_str(2)) - self.assertEqual("tasks", helpers.tasks_str(-10)) - self.assertEqual("tasks", helpers.tasks_str(42)) - def test_get_first_doc_line(self): expected_output = 'foo bar baz' mock_obj = MagicMock(__doc__=f"""{expected_output} diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..04b39bb --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,259 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +Unittests for the `asyncio_taskpool.control.parser` module. +""" + + +from argparse import ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, SUPPRESS +from inspect import signature +from unittest import TestCase +from unittest.mock import MagicMock, call, patch + +from asyncio_taskpool import parser +from asyncio_taskpool.exceptions import HelpRequested + + +FOO, BAR = 'foo', 'bar' + + +class ControlServerTestCase(TestCase): + + def setUp(self) -> None: + self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory') + self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() + self.mock_help_formatter_factory.return_value = RawTextHelpFormatter + self.stream_writer, self.terminal_width = MagicMock(), 420 + self.kwargs = { + 'stream_writer': self.stream_writer, + 'terminal_width': self.terminal_width, + parser.FORMATTER_CLASS: FOO + } + self.parser = parser.ControlParser(**self.kwargs) + + def tearDown(self) -> None: + self.help_formatter_factory_patcher.stop() + + def test_help_formatter_factory(self): + self.help_formatter_factory_patcher.stop() + + class MockBaseClass(HelpFormatter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + terminal_width = 123456789 + cls = parser.ControlParser.help_formatter_factory(terminal_width, MockBaseClass) + self.assertTrue(issubclass(cls, MockBaseClass)) + instance = cls('prog') + self.assertEqual(terminal_width, getattr(instance, '_width')) + + cls = parser.ControlParser.help_formatter_factory(terminal_width) + self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter)) + instance = cls('prog') + self.assertEqual(terminal_width, getattr(instance, '_width')) + + def test_init(self): + self.assertIsInstance(self.parser, ArgumentParser) + self.assertEqual(self.stream_writer, self.parser._stream_writer) + self.assertEqual(self.terminal_width, self.parser._terminal_width) + self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) + self.assertFalse(getattr(self.parser, 'exit_on_error')) + self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) + self.assertSetEqual(set(), self.parser._flags) + self.assertIsNone(self.parser._commands) + + @patch.object(parser, 'get_first_doc_line') + def test_add_function_command(self, mock_get_first_doc_line: MagicMock): + def foo_bar(): pass + mock_subparser = MagicMock() + mock_add_parser = MagicMock(return_value=mock_subparser) + self.parser._commands = MagicMock(add_parser=mock_add_parser) + mock_get_first_doc_line.return_value = mock_help = 'help 123' + kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} + expected_name = 'foo-bar' + expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs + to_omit = ['abc', 'xyz'] + output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs) + self.assertEqual(mock_subparser, output) + mock_add_parser.assert_called_once_with(**expected_kwargs) + mock_subparser.add_function_args.assert_called_once_with(foo_bar, to_omit) + + @patch.object(parser, 'get_first_doc_line') + def test_add_property_command(self, mock_get_first_doc_line: MagicMock): + def get_prop(_self): pass + def set_prop(_self, _value): pass + prop = property(get_prop) + mock_subparser = MagicMock() + mock_add_parser = MagicMock(return_value=mock_subparser) + self.parser._commands = MagicMock(add_parser=mock_add_parser) + mock_get_first_doc_line.return_value = mock_help = 'help 123' + kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} + expected_name = 'get-prop' + expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs + output = self.parser.add_property_command(prop, **kwargs) + self.assertEqual(mock_subparser, output) + mock_get_first_doc_line.assert_called_once_with(get_prop) + mock_add_parser.assert_called_once_with(**expected_kwargs) + mock_subparser.add_function_arg.assert_not_called() + + mock_get_first_doc_line.reset_mock() + mock_add_parser.reset_mock() + + prop = property(get_prop, set_prop) + expected_help = f"Get/set the `.{expected_name}` property" + expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs + output = self.parser.add_property_command(prop, **kwargs) + self.assertEqual(mock_subparser, output) + mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)]) + mock_add_parser.assert_called_once_with(**expected_kwargs) + mock_subparser.add_function_arg.assert_called_once_with( + tuple(signature(set_prop).parameters.values())[1], + nargs='?', + default=SUPPRESS, + help=f"If provided: {mock_help} If omitted: {mock_help}" + ) + + @patch.object(parser.ControlParser, 'add_property_command') + @patch.object(parser.ControlParser, 'add_function_command') + def test_add_class_commands(self, mock_add_function_command: MagicMock, mock_add_property_command: MagicMock): + class FooBar: + some_attribute = None + + def _protected(self, _): pass + + def __private(self, _): pass + + def to_omit(self, _): pass + + def method(self, _): pass + + @property + def prop(self): return None + + mock_set_defaults = MagicMock() + mock_subparser = MagicMock(set_defaults=mock_set_defaults) + mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser + x = 'x' + common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer, + parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width} + expected_output = {'method': mock_subparser, 'prop': mock_subparser} + output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x) + self.assertDictEqual(expected_output, output) + mock_add_function_command.assert_called_once_with(FooBar.method, **common_kwargs) + mock_add_property_command.assert_called_once_with(FooBar.prop, FooBar.__name__, **common_kwargs) + mock_set_defaults.assert_has_calls([call(**{x: FooBar.method}), call(**{x: FooBar.prop})]) + + def test__print_message(self): + self.stream_writer.write = MagicMock() + self.assertIsNone(self.parser._print_message('')) + self.stream_writer.write.assert_not_called() + msg = 'foo bar baz' + self.assertIsNone(self.parser._print_message(msg)) + self.stream_writer.write.assert_called_once_with(msg.encode()) + + @patch.object(parser.ControlParser, '_print_message') + def test_exit(self, mock__print_message: MagicMock): + self.assertIsNone(self.parser.exit(123, '')) + mock__print_message.assert_not_called() + msg = 'foo bar baz' + self.assertIsNone(self.parser.exit(123, msg)) + mock__print_message.assert_called_once_with(msg) + + @patch.object(parser.ArgumentParser, 'error') + def test_error(self, mock_supercls_error: MagicMock): + with self.assertRaises(HelpRequested): + self.parser.error(FOO + BAR) + mock_supercls_error.assert_called_once_with(message=FOO + BAR) + + @patch.object(parser.ArgumentParser, 'print_help') + def test_print_help(self, mock_print_help: MagicMock): + arg = MagicMock() + with self.assertRaises(HelpRequested): + self.parser.print_help(arg) + mock_print_help.assert_called_once_with(arg) + + @patch.object(parser, 'get_arg_type_wrapper') + @patch.object(parser.ArgumentParser, 'add_argument') + def test_add_function_arg(self, mock_add_argument: MagicMock, mock_get_arg_type_wrapper: MagicMock): + mock_add_argument.return_value = expected_output = 'action' + mock_get_arg_type_wrapper.return_value = mock_type = 'fake' + + foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex + bar_default, baz_default, boo_default = 1, 0.1, 1j + + def func(foo: foo_type, *args: args_type, bar: bar_type = bar_default, baz: baz_type = baz_default, + boo: boo_type = boo_default, flag: bool = False): + return foo, args, bar, baz, boo, flag + + param_foo, param_args, param_bar, param_baz, param_boo, param_flag = signature(func).parameters.values() + kwargs = {FOO + BAR: 'xyz'} + self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs)) + mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs) + mock_get_arg_type_wrapper.assert_called_once_with(foo_type) + + mock_add_argument.reset_mock() + mock_get_arg_type_wrapper.reset_mock() + + self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs)) + mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs) + mock_get_arg_type_wrapper.assert_called_once_with(args_type) + + mock_add_argument.reset_mock() + mock_get_arg_type_wrapper.reset_mock() + + self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs)) + mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs) + mock_get_arg_type_wrapper.assert_called_once_with(bar_type) + + mock_add_argument.reset_mock() + mock_get_arg_type_wrapper.reset_mock() + + self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs)) + mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs) + mock_get_arg_type_wrapper.assert_called_once_with(baz_type) + + mock_add_argument.reset_mock() + mock_get_arg_type_wrapper.reset_mock() + + self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs)) + mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs) + mock_get_arg_type_wrapper.assert_called_once_with(boo_type) + + mock_add_argument.reset_mock() + mock_get_arg_type_wrapper.reset_mock() + + self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs)) + mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs) + mock_get_arg_type_wrapper.assert_not_called() + + @patch.object(parser.ControlParser, 'add_function_arg') + def test_add_function_args(self, mock_add_function_arg: MagicMock): + def func(foo: str, *args: int, bar: float = 0.1): + return foo, args, bar + _, param_args, param_bar = signature(func).parameters.values() + self.assertIsNone(self.parser.add_function_args(func, omit=['foo'])) + mock_add_function_arg.assert_has_calls([ + call(param_args, help=repr(param_args.annotation)), + call(param_bar, help=repr(param_bar.annotation)), + ]) + + +class RestTestCase(TestCase): + def test_get_arg_type_wrapper(self): + type_wrap = parser.get_arg_type_wrapper(int) + self.assertEqual(SUPPRESS, type_wrap(SUPPRESS)) + self.assertEqual(13, type_wrap('13')) diff --git a/tests/test_session.py b/tests/test_session.py index 2a3eb1d..fec98d7 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -25,9 +25,9 @@ from unittest import IsolatedAsyncioTestCase from unittest.mock import AsyncMock, MagicMock, patch, call from asyncio_taskpool import session -from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER -from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass -from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool +from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER +from asyncio_taskpool.exceptions import HelpRequested +from asyncio_taskpool.pool import SimpleTaskPool FOO, BAR = 'foo', 'bar' @@ -61,236 +61,105 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): self.assertEqual(self.mock_reader, self.session._reader) self.assertEqual(self.mock_writer, self.session._writer) self.assertIsNone(self.session._parser) - self.assertIsNone(self.session._subparsers) - def test__add_command(self): - expected_output = 123456 - mock_add_parser = MagicMock(return_value=expected_output) - self.session._subparsers = MagicMock(add_parser=mock_add_parser) - self.session._parser = MagicMock() - name, prog, short_help, long_help = 'abc', None, 'short123', None - kwargs = {'x': 1, 'y': 2} - output = self.session._add_command(name, prog, short_help, long_help, **kwargs) - self.assertEqual(expected_output, output) - mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help, - parent=self.session._parser, **kwargs) + @patch.object(session, 'return_or_exception') + async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock): + def method(self, arg1, arg2, *var_args, **rest): pass + test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11} + kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest + mock_return_or_exception.return_value = None + self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs)) + mock_return_or_exception.assert_awaited_once_with( + method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest + ) + self.mock_writer.write.assert_called_once_with(session.CMD_OK) - mock_add_parser.reset_mock() + @patch.object(session, 'return_or_exception') + async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock): + def prop_get(_): pass + def prop_set(_): pass + prop = property(prop_get, prop_set) + kwargs = {'value': 'something'} + mock_return_or_exception.return_value = None + self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs)) + mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs) + self.mock_writer.write.assert_called_once_with(session.CMD_OK) - prog, long_help = 'ffffff', 'so long, wow' - output = self.session._add_command(name, prog, short_help, long_help, **kwargs) - self.assertEqual(expected_output, output) - mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help, - parent=self.session._parser, **kwargs) + mock_return_or_exception.reset_mock() + self.mock_writer.write.reset_mock() - mock_add_parser.reset_mock() + mock_return_or_exception.return_value = val = 420.69 + self.assertIsNone(await self.session._exec_property_and_respond(prop)) + mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool) + self.mock_writer.write.assert_called_once_with(str(val).encode()) - short_help = None - output = self.session._add_command(name, prog, short_help, long_help, **kwargs) - self.assertEqual(expected_output, output) - mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help, - parent=self.session._parser, **kwargs) - - @patch.object(session, 'get_first_doc_line') - @patch.object(session.ControlSession, '_add_command') - def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock): - self.assertIsNone(self.session._add_base_commands()) - mock__add_command.assert_called() - mock_get_first_doc_line.assert_called() - - mock__add_command.reset_mock() - mock_get_first_doc_line.reset_mock() - - self.assertIsNone(self.session._add_simple_commands()) - mock__add_command.assert_called() - mock_get_first_doc_line.assert_called() - - with self.assertRaises(NotImplementedError): - self.session._add_advanced_commands() - - @patch.object(session.ControlSession, '_add_simple_commands') - @patch.object(session.ControlSession, '_add_advanced_commands') - @patch.object(session.ControlSession, '_add_base_commands') - @patch.object(session, 'CommandParser') - def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock, - mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock): - mock_command_parser_cls.return_value = mock_parser = MagicMock() - self.session._pool = TaskPool() - width = 1234 - expected_parser_kwargs = { - 'prog': '', - SESSION_WRITER: self.mock_writer, - CLIENT_INFO.TERMINAL_WIDTH: width, - } - self.assertIsNone(self.session._init_parser(width)) - mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) - mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) - mock__add_base_commands.assert_called_once_with() - mock__add_advanced_commands.assert_called_once_with() - mock__add_simple_commands.assert_not_called() - - mock_command_parser_cls.reset_mock() - mock_parser.add_subparsers.reset_mock() - mock__add_base_commands.reset_mock() - mock__add_advanced_commands.reset_mock() - mock__add_simple_commands.reset_mock() - - async def fake_coroutine(): pass - - self.session._pool = SimpleTaskPool(fake_coroutine) - self.assertIsNone(self.session._init_parser(width)) - mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) - mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) - mock__add_base_commands.assert_called_once_with() - mock__add_advanced_commands.assert_not_called() - mock__add_simple_commands.assert_called_once_with() - - mock_command_parser_cls.reset_mock() - mock_parser.add_subparsers.reset_mock() - mock__add_base_commands.reset_mock() - mock__add_advanced_commands.reset_mock() - mock__add_simple_commands.reset_mock() - - class FakeTaskPool(BaseTaskPool): - pass - - self.session._pool = FakeTaskPool() - with self.assertRaises(UnknownTaskPoolClass): - self.session._init_parser(width) - mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) - mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) - mock__add_base_commands.assert_called_once_with() - mock__add_advanced_commands.assert_not_called() - mock__add_simple_commands.assert_not_called() - - mock_command_parser_cls.reset_mock() - mock_parser.add_subparsers.reset_mock() - mock__add_base_commands.reset_mock() - mock__add_advanced_commands.reset_mock() - mock__add_simple_commands.reset_mock() - - self.session._pool = MagicMock() - with self.assertRaises(NotATaskPool): - self.session._init_parser(width) - mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) - mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) - mock__add_base_commands.assert_called_once_with() - mock__add_advanced_commands.assert_not_called() - mock__add_simple_commands.assert_not_called() - - @patch.object(session.ControlSession, '_init_parser') - async def test_client_handshake(self, mock__init_parser: MagicMock): + @patch.object(session, 'ControlParser') + async def test_client_handshake(self, mock_parser_cls: MagicMock): + mock_add_subparsers, mock_add_class_commands = MagicMock(), MagicMock() + mock_parser = MagicMock(add_subparsers=mock_add_subparsers, add_class_commands=mock_add_class_commands) + mock_parser_cls.return_value = mock_parser width = 5678 msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' ' mock_read = AsyncMock(return_value=msg.encode()) self.mock_reader.read = mock_read self.mock_writer.drain = AsyncMock() + expected_parser_kwargs = { + STREAM_WRITER: self.mock_writer, + CLIENT_INFO.TERMINAL_WIDTH: width, + 'prog': '', + 'usage': f'%(prog)s [-h] [{CMD}] ...' + } + expected_subparsers_kwargs = { + 'title': "Commands", + 'metavar': "(A command followed by '-h' or '--help' will show command-specific help.)" + } self.assertIsNone(await self.session.client_handshake()) + self.assertEqual(mock_parser, self.session._parser) mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) - mock__init_parser.assert_called_once_with(width) + mock_parser_cls.assert_called_once_with(**expected_parser_kwargs) + mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs) + mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__) self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode()) self.mock_writer.drain.assert_awaited_once_with() - @patch.object(session, 'return_or_exception') - async def test__write_function_output(self, mock_return_or_exception: MagicMock): - self.mock_writer.write = MagicMock() - mock_return_or_exception.return_value = None - func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'} - self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs)) - mock_return_or_exception.assert_called_once_with(func, *args, **kwargs) - self.mock_writer.write.assert_called_once_with(b"ok") - - mock_return_or_exception.reset_mock() - self.mock_writer.write.reset_mock() - - mock_return_or_exception.return_value = output = MagicMock() - self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs)) - mock_return_or_exception.assert_called_once_with(func, *args, **kwargs) - self.mock_writer.write.assert_called_once_with(str(output).encode()) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_name(self, mock__write_function_output: AsyncMock): - self.assertIsNone(await self.session._cmd_name()) - mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock): - num = 12345 - kwargs = {session.NUM: num, FOO: BAR} - self.assertIsNone(await self.session._cmd_pool_size(**kwargs)) - mock__write_function_output.assert_awaited_once_with( - self.mock_pool.__class__.pool_size.fset, self.session._pool, num - ) - - mock__write_function_output.reset_mock() - - kwargs.pop(session.NUM) - self.assertIsNone(await self.session._cmd_pool_size(**kwargs)) - mock__write_function_output.assert_awaited_once_with( - self.mock_pool.__class__.pool_size.fget, self.session._pool - ) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_num_running(self, mock__write_function_output: AsyncMock): - self.assertIsNone(await self.session._cmd_num_running()) - mock__write_function_output.assert_awaited_once_with( - self.mock_pool.__class__.num_running.fget, self.session._pool - ) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_start(self, mock__write_function_output: AsyncMock): - num = 12345 - kwargs = {session.NUM: num, FOO: BAR} - self.assertIsNone(await self.session._cmd_start(**kwargs)) - mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_stop(self, mock__write_function_output: AsyncMock): - num = 12345 - kwargs = {session.NUM: num, FOO: BAR} - self.assertIsNone(await self.session._cmd_stop(**kwargs)) - mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock): - self.assertIsNone(await self.session._cmd_stop_all()) - mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all) - - @patch.object(session.ControlSession, '_write_function_output') - async def test__cmd_func_name(self, mock__write_function_output: AsyncMock): - self.assertIsNone(await self.session._cmd_func_name()) - mock__write_function_output.assert_awaited_once_with( - self.mock_pool.__class__.func_name.fget, self.session._pool - ) - - async def test__execute_command(self): - mock_method = AsyncMock() - cmd = 'this-is-a-test' - setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method) - kwargs = {FOO: BAR, 'hello': 'python'} - self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs)) - mock_method.assert_awaited_once_with(**kwargs) - - @patch.object(session.ControlSession, '_execute_command') - async def test__parse_command(self, mock__execute_command: AsyncMock): + @patch.object(session.ControlSession, '_exec_property_and_respond') + @patch.object(session.ControlSession, '_exec_method_and_respond') + async def test__parse_command(self, mock__exec_method_and_respond: AsyncMock, + mock__exec_property_and_respond: AsyncMock): + def method(_): pass + prop = property(method) msg = 'asdf asd as a' kwargs = {FOO: BAR, 'hello': 'python'} - mock_parse_args = MagicMock(return_value=Namespace(**kwargs)) + mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs)) self.session._parser = MagicMock(parse_args=mock_parse_args) self.mock_writer.write = MagicMock() self.assertIsNone(await self.session._parse_command(msg)) mock_parse_args.assert_called_once_with(msg.split(' ')) self.mock_writer.write.assert_not_called() - mock__execute_command.assert_awaited_once_with(**kwargs) + mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs) + mock__exec_property_and_respond.assert_not_called() - mock__execute_command.reset_mock() + mock__exec_method_and_respond.reset_mock() + mock_parse_args.reset_mock() + + mock_parse_args = MagicMock(return_value=Namespace(**{CMD: prop}, **kwargs)) + self.session._parser = MagicMock(parse_args=mock_parse_args) + self.mock_writer.write = MagicMock() + self.assertIsNone(await self.session._parse_command(msg)) + mock_parse_args.assert_called_once_with(msg.split(' ')) + self.mock_writer.write.assert_not_called() + mock__exec_method_and_respond.assert_not_called() + mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs) + + mock__exec_property_and_respond.reset_mock() mock_parse_args.reset_mock() mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") self.assertIsNone(await self.session._parse_command(msg)) mock_parse_args.assert_called_once_with(msg.split(' ')) self.mock_writer.write.assert_called_once_with(str(exc).encode()) - mock__execute_command.assert_not_awaited() + mock__exec_method_and_respond.assert_not_awaited() + mock__exec_property_and_respond.assert_not_awaited() self.mock_writer.write.reset_mock() mock_parse_args.reset_mock() @@ -299,7 +168,8 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): self.assertIsNone(await self.session._parse_command(msg)) mock_parse_args.assert_called_once_with(msg.split(' ')) self.mock_writer.write.assert_not_called() - mock__execute_command.assert_not_awaited() + mock__exec_method_and_respond.assert_not_awaited() + mock__exec_property_and_respond.assert_not_awaited() @patch.object(session.ControlSession, '_parse_command') async def test_listen(self, mock__parse_command: AsyncMock): diff --git a/tests/test_session_parser.py b/tests/test_session_parser.py deleted file mode 100644 index da7fff8..0000000 --- a/tests/test_session_parser.py +++ /dev/null @@ -1,134 +0,0 @@ -__author__ = "Daniil Fajnberg" -__copyright__ = "Copyright © 2022 Daniil Fajnberg" -__license__ = """GNU LGPLv3.0 - -This file is part of asyncio-taskpool. - -asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of -version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. - -asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. -See the GNU Lesser General Public License for more details. - -You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. -If not, see .""" - -__doc__ = """ -Unittests for the `asyncio_taskpool.session_parser` module. -""" - - -from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter -from unittest import IsolatedAsyncioTestCase -from unittest.mock import MagicMock, patch - -from asyncio_taskpool import session_parser -from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO -from asyncio_taskpool.exceptions import HelpRequested - - -FOO = 'foo' - - -class ControlServerTestCase(IsolatedAsyncioTestCase): - - def setUp(self) -> None: - self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory') - self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() - self.mock_help_formatter_factory.return_value = RawTextHelpFormatter - self.session_writer, self.terminal_width = MagicMock(), 420 - self.kwargs = { - SESSION_WRITER: self.session_writer, - CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width, - session_parser.FORMATTER_CLASS: FOO - } - self.parser = session_parser.CommandParser(**self.kwargs) - - def tearDown(self) -> None: - self.help_formatter_factory_patcher.stop() - - def test_help_formatter_factory(self): - self.help_formatter_factory_patcher.stop() - - class MockBaseClass(HelpFormatter): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - terminal_width = 123456789 - cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass) - self.assertTrue(issubclass(cls, MockBaseClass)) - instance = cls('prog') - self.assertEqual(terminal_width, getattr(instance, '_width')) - - cls = session_parser.CommandParser.help_formatter_factory(terminal_width) - self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter)) - instance = cls('prog') - self.assertEqual(terminal_width, getattr(instance, '_width')) - - def test_init(self): - self.assertIsInstance(self.parser, ArgumentParser) - self.assertEqual(self.session_writer, self.parser._session_writer) - self.assertEqual(self.terminal_width, self.parser._terminal_width) - self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) - self.assertFalse(getattr(self.parser, 'exit_on_error')) - self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) - - def test_session_writer(self): - self.assertEqual(self.session_writer, self.parser.session_writer) - - def test_terminal_width(self): - self.assertEqual(self.terminal_width, self.parser.terminal_width) - - def test__print_message(self): - self.session_writer.write = MagicMock() - self.assertIsNone(self.parser._print_message('')) - self.session_writer.write.assert_not_called() - msg = 'foo bar baz' - self.assertIsNone(self.parser._print_message(msg)) - self.session_writer.write.assert_called_once_with(msg.encode()) - - @patch.object(session_parser.CommandParser, '_print_message') - def test_exit(self, mock__print_message: MagicMock): - self.assertIsNone(self.parser.exit(123, '')) - mock__print_message.assert_not_called() - msg = 'foo bar baz' - self.assertIsNone(self.parser.exit(123, msg)) - mock__print_message.assert_called_once_with(msg) - - @patch.object(session_parser.ArgumentParser, 'print_help') - def test_print_help(self, mock_print_help: MagicMock): - arg = MagicMock() - with self.assertRaises(HelpRequested): - self.parser.print_help(arg) - mock_print_help.assert_called_once_with(arg) - - def test_add_optional_num_argument(self): - metavar = 'FOOBAR' - action = self.parser.add_optional_num_argument(metavar=metavar) - self.assertIsInstance(action, Action) - self.assertEqual('?', action.nargs) - self.assertEqual(1, action.default) - self.assertEqual(int, action.type) - self.assertEqual(metavar, action.metavar) - num = 111 - kwargs = vars(self.parser.parse_args([f'{num}'])) - self.assertDictEqual({session_parser.NUM: num}, kwargs) - - name = f'--{FOO}' - nargs = '+' - default = 1 - _type = float - required = True - dest = 'foo_bar' - action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type, - required=required, metavar=metavar, dest=dest) - self.assertIsInstance(action, Action) - self.assertEqual(nargs, action.nargs) - self.assertEqual(default, action.default) - self.assertEqual(_type, action.type) - self.assertEqual(required, action.required) - self.assertEqual(metavar, action.metavar) - self.assertEqual(dest, action.dest) - kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5'])) - self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)