diff --git a/setup.cfg b/setup.cfg index c0caba2..3be8096 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.3.3 +version = 0.3.4 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index 19d56c7..c4f0869 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -24,7 +24,7 @@ PACKAGE_NAME = 'asyncio_taskpool' CLIENT_EXIT = 'exit' SESSION_MSG_BYTES = 1024 * 100 -SESSION_PARSER_WRITER = 'session_writer' +SESSION_WRITER = 'session_writer' class CLIENT_INFO: diff --git a/src/asyncio_taskpool/session.py b/src/asyncio_taskpool/session.py index 6db1d01..2216076 100644 --- a/src/asyncio_taskpool/session.py +++ b/src/asyncio_taskpool/session.py @@ -25,7 +25,7 @@ from argparse import ArgumentError, HelpFormatter from asyncio.streams import StreamReader, StreamWriter from typing import Callable, Optional, Union, TYPE_CHECKING -from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO +from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass from .helpers import get_first_doc_line, return_or_exception, tasks_str from .pool import BaseTaskPool, TaskPool, SimpleTaskPool @@ -166,7 +166,7 @@ class ControlSession: """ parser_kwargs = { 'prog': '', - SESSION_PARSER_WRITER: self._writer, + SESSION_WRITER: self._writer, CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, } self._parser = CommandParser(**parser_kwargs) diff --git a/src/asyncio_taskpool/session_parser.py b/src/asyncio_taskpool/session_parser.py index a315234..d27e0cc 100644 --- a/src/asyncio_taskpool/session_parser.py +++ b/src/asyncio_taskpool/session_parser.py @@ -1,8 +1,29 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +This module contains the the definition of the `CommandParser` class used in a control server session. +""" + + from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter from asyncio.streams import StreamWriter from typing import Type, TypeVar -from .constants import SESSION_PARSER_WRITER, CLIENT_INFO +from .constants import SESSION_WRITER, CLIENT_INFO from .exceptions import HelpRequested @@ -12,8 +33,33 @@ NUM = 'num' class CommandParser(ArgumentParser): + """ + Subclass of the standard `argparse.ArgumentParser` for remote interaction. + + Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` + instance passed to it during initialization. + Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a + connected client. + Finally, it offers some convenience methods and makes use of custom exceptions. + """ + @staticmethod def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: + """ + Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument. + + Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not + as convenient, when making use of sub-parsers. + + Args: + terminal_width: + The number of columns of the terminal to which to adjust help formatting. + base_cls (optional): + The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used. + + Returns: + The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`. + """ if base_cls is None: base_cls = ArgumentDefaultsHelpFormatter @@ -23,35 +69,56 @@ class CommandParser(ArgumentParser): super().__init__(*args, **kwargs) return ClientHelpFormatter - def __init__(self, *args, **kwargs) -> None: - parent: CommandParser = kwargs.pop('parent', None) - self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER) + def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None: + """ + Sets additional internal attributes depending on whether a parent-parser was defined. + + The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword. + By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). + + Args: + parent (optional): + An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method + of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present, + the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`. + **kwargs(optional): + In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of + the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument + is empty. + """ + self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER) self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH) kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) kwargs.setdefault('exit_on_error', False) - super().__init__(*args, **kwargs) + super().__init__(**kwargs) @property - def stream_writer(self) -> StreamWriter: - return self._stream_writer + def session_writer(self) -> StreamWriter: + """Returns the predefined stream writer object of the control session.""" + return self._session_writer @property def terminal_width(self) -> int: + """Returns the predefined terminal width.""" return self._terminal_width def _print_message(self, message: str, *args, **kwargs) -> None: + """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" if message: - self.stream_writer.write(message.encode()) + self._session_writer.write(message.encode()) def exit(self, status: int = 0, message: str = None) -> None: + """This is overridden to prevent system exit to be invoked.""" if message: self._print_message(message) def print_help(self, file=None) -> None: + """This just adds the custom `HelpRequested` exception after the parent class' method.""" super().print_help(file) raise HelpRequested def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action: + """Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified.""" if not name_or_flags: name_or_flags = (NUM, ) kwargs.setdefault('nargs', '?') diff --git a/tests/test_session.py b/tests/test_session.py index 921e06d..2a3eb1d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -25,7 +25,7 @@ from unittest import IsolatedAsyncioTestCase from unittest.mock import AsyncMock, MagicMock, patch, call from asyncio_taskpool import session -from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER +from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool @@ -119,7 +119,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): width = 1234 expected_parser_kwargs = { 'prog': '', - SESSION_PARSER_WRITER: self.mock_writer, + SESSION_WRITER: self.mock_writer, CLIENT_INFO.TERMINAL_WIDTH: width, } self.assertIsNone(self.session._init_parser(width)) diff --git a/tests/test_session_parser.py b/tests/test_session_parser.py new file mode 100644 index 0000000..da7fff8 --- /dev/null +++ b/tests/test_session_parser.py @@ -0,0 +1,134 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +Unittests for the `asyncio_taskpool.session_parser` module. +""" + + +from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter +from unittest import IsolatedAsyncioTestCase +from unittest.mock import MagicMock, patch + +from asyncio_taskpool import session_parser +from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO +from asyncio_taskpool.exceptions import HelpRequested + + +FOO = 'foo' + + +class ControlServerTestCase(IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory') + self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() + self.mock_help_formatter_factory.return_value = RawTextHelpFormatter + self.session_writer, self.terminal_width = MagicMock(), 420 + self.kwargs = { + SESSION_WRITER: self.session_writer, + CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width, + session_parser.FORMATTER_CLASS: FOO + } + self.parser = session_parser.CommandParser(**self.kwargs) + + def tearDown(self) -> None: + self.help_formatter_factory_patcher.stop() + + def test_help_formatter_factory(self): + self.help_formatter_factory_patcher.stop() + + class MockBaseClass(HelpFormatter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + terminal_width = 123456789 + cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass) + self.assertTrue(issubclass(cls, MockBaseClass)) + instance = cls('prog') + self.assertEqual(terminal_width, getattr(instance, '_width')) + + cls = session_parser.CommandParser.help_formatter_factory(terminal_width) + self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter)) + instance = cls('prog') + self.assertEqual(terminal_width, getattr(instance, '_width')) + + def test_init(self): + self.assertIsInstance(self.parser, ArgumentParser) + self.assertEqual(self.session_writer, self.parser._session_writer) + self.assertEqual(self.terminal_width, self.parser._terminal_width) + self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) + self.assertFalse(getattr(self.parser, 'exit_on_error')) + self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) + + def test_session_writer(self): + self.assertEqual(self.session_writer, self.parser.session_writer) + + def test_terminal_width(self): + self.assertEqual(self.terminal_width, self.parser.terminal_width) + + def test__print_message(self): + self.session_writer.write = MagicMock() + self.assertIsNone(self.parser._print_message('')) + self.session_writer.write.assert_not_called() + msg = 'foo bar baz' + self.assertIsNone(self.parser._print_message(msg)) + self.session_writer.write.assert_called_once_with(msg.encode()) + + @patch.object(session_parser.CommandParser, '_print_message') + def test_exit(self, mock__print_message: MagicMock): + self.assertIsNone(self.parser.exit(123, '')) + mock__print_message.assert_not_called() + msg = 'foo bar baz' + self.assertIsNone(self.parser.exit(123, msg)) + mock__print_message.assert_called_once_with(msg) + + @patch.object(session_parser.ArgumentParser, 'print_help') + def test_print_help(self, mock_print_help: MagicMock): + arg = MagicMock() + with self.assertRaises(HelpRequested): + self.parser.print_help(arg) + mock_print_help.assert_called_once_with(arg) + + def test_add_optional_num_argument(self): + metavar = 'FOOBAR' + action = self.parser.add_optional_num_argument(metavar=metavar) + self.assertIsInstance(action, Action) + self.assertEqual('?', action.nargs) + self.assertEqual(1, action.default) + self.assertEqual(int, action.type) + self.assertEqual(metavar, action.metavar) + num = 111 + kwargs = vars(self.parser.parse_args([f'{num}'])) + self.assertDictEqual({session_parser.NUM: num}, kwargs) + + name = f'--{FOO}' + nargs = '+' + default = 1 + _type = float + required = True + dest = 'foo_bar' + action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type, + required=required, metavar=metavar, dest=dest) + self.assertIsInstance(action, Action) + self.assertEqual(nargs, action.nargs) + self.assertEqual(default, action.default) + self.assertEqual(_type, action.type) + self.assertEqual(required, action.required) + self.assertEqual(metavar, action.metavar) + self.assertEqual(dest, action.dest) + kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5'])) + self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)