diff --git a/setup.cfg b/setup.cfg index c83c4fe..b5c3e59 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.6.3 +version = 0.6.4 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/control/client.py b/src/asyncio_taskpool/control/client.py index b779a97..20fe302 100644 --- a/src/asyncio_taskpool/control/client.py +++ b/src/asyncio_taskpool/control/client.py @@ -41,7 +41,7 @@ class ControlClient(ABC): """ @staticmethod - def client_info() -> dict: + def _client_info() -> dict: """Returns a dictionary of client information relevant for the handshake with the server.""" return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns} @@ -73,7 +73,7 @@ class ControlClient(ABC): writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method """ self._connected = True - writer.write(json.dumps(self.client_info()).encode()) + writer.write(json.dumps(self._client_info()).encode()) await writer.drain() print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) print("Type '-h' to get help and usage instructions for all available commands.\n") diff --git a/src/asyncio_taskpool/control/parser.py b/src/asyncio_taskpool/control/parser.py index 75352ed..48dcc3e 100644 --- a/src/asyncio_taskpool/control/parser.py +++ b/src/asyncio_taskpool/control/parser.py @@ -23,10 +23,10 @@ from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, Help from asyncio.streams import StreamWriter from inspect import Parameter, getmembers, isfunction, signature from shutil import get_terminal_size -from typing import Callable, Container, Dict, Set, Type, TypeVar +from typing import Any, Callable, Container, Dict, Set, Type, TypeVar from ..constants import CLIENT_INFO, CMD, STREAM_WRITER -from ..exceptions import HelpRequested +from ..exceptions import HelpRequested, ParserError from ..helpers import get_first_doc_line @@ -35,7 +35,6 @@ ParsersDict = Dict[str, 'ControlParser'] OMIT_PARAMS_DEFAULT = ('self', ) -FORMATTER_CLASS = 'formatter_class' NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description' @@ -79,24 +78,23 @@ class ControlParser(ArgumentParser): def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None: """ - Sets additional internal attributes depending on whether a parent-parser was defined. + Subclass of the `ArgumentParser` geared towards asynchronous interaction with an object "from the outside". - The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword. - By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). + Allows directing output to a specified writer rather than stdout/stderr and setting terminal width explicitly. Args: stream_writer: The instance of the `asyncio.StreamWriter` to use for message output. terminal_width (optional): - The terminal width to assume for all message formatting. Defaults to `shutil.get_terminal_size`. + The terminal width to use for all message formatting. Defaults to `shutil.get_terminal_size().columns`. **kwargs(optional): - In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of - the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument - is empty. + Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a + class is specified, it will always be subclassed in the `help_formatter_factory`. + Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). """ self._stream_writer: StreamWriter = stream_writer self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns - kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) + kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class')) kwargs.setdefault('exit_on_error', False) super().__init__(**kwargs) self._flags: Set[str] = set() @@ -219,7 +217,7 @@ class ControlParser(ArgumentParser): def error(self, message: str) -> None: """This just adds the custom `HelpRequested` exception after the parent class' method.""" super().error(message=message) - raise HelpRequested + raise ParserError def print_help(self, file=None) -> None: """This just adds the custom `HelpRequested` exception after the parent class' method.""" @@ -267,9 +265,8 @@ class ControlParser(ArgumentParser): # This is to be able to later unpack an arbitrary number of positional arguments. kwargs.setdefault('nargs', '*') if not kwargs.get('action') == 'store_true': - # The lambda wrapper around the type annotation is to avoid ValueError being raised on suppressed arguments. - # See: https://bugs.python.org/issue36078 - kwargs.setdefault('type', get_arg_type_wrapper(parameter.annotation)) + # Set the type from the parameter annotation. + kwargs.setdefault('type', _get_arg_type_wrapper(parameter.annotation)) return self.add_argument(*name_or_flags, **kwargs) def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None: @@ -293,7 +290,13 @@ class ControlParser(ArgumentParser): self.add_function_arg(param, help=repr(param.annotation)) -def get_arg_type_wrapper(cls: Type) -> Callable: - def wrapper(arg): - return arg if arg is SUPPRESS else cls(arg) +def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]: + """ + Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments. + + See: https://bugs.python.org/issue36078 + """ + def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg) + # Copy the name of the class to maintain useful help messages when incorrect arguments are passed. + wrapper.__name__ = cls.__name__ return wrapper diff --git a/src/asyncio_taskpool/control/server.py b/src/asyncio_taskpool/control/server.py index a506bf7..97286e2 100644 --- a/src/asyncio_taskpool/control/server.py +++ b/src/asyncio_taskpool/control/server.py @@ -125,6 +125,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta async def serve_forever(self) -> Task: """ This method actually starts the server and begins listening to client connections on the specified interface. + It should never block because the serving will be performed in a separate task. """ log.debug("Starting %s...", self.__class__.__name__) diff --git a/src/asyncio_taskpool/control/session.py b/src/asyncio_taskpool/control/session.py index 17cee99..a7aefda 100644 --- a/src/asyncio_taskpool/control/session.py +++ b/src/asyncio_taskpool/control/session.py @@ -27,7 +27,7 @@ from inspect import isfunction, signature from typing import Callable, Optional, Union, TYPE_CHECKING from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER -from ..exceptions import CommandError, HelpRequested +from ..exceptions import CommandError, HelpRequested, ParserError from ..helpers import return_or_exception from ..pool import TaskPool, SimpleTaskPool from .parser import ControlParser @@ -157,7 +157,7 @@ class ControlSession: log.debug("%s got an ArgumentError", self._client_class_name) self._writer.write(str(e).encode()) return - except HelpRequested: + except (HelpRequested, ParserError): log.debug("%s received usage help", self._client_class_name) return command = kwargs.pop(CMD) diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py index c6858f8..c911e08 100644 --- a/src/asyncio_taskpool/exceptions.py +++ b/src/asyncio_taskpool/exceptions.py @@ -67,5 +67,9 @@ class HelpRequested(ServerException): pass +class ParserError(ServerException): + pass + + class CommandError(ServerException): pass diff --git a/src/asyncio_taskpool/helpers.py b/src/asyncio_taskpool/helpers.py index 8f44d37..591cb93 100644 --- a/src/asyncio_taskpool/helpers.py +++ b/src/asyncio_taskpool/helpers.py @@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w If not, see .""" __doc__ = """ -Miscellaneous helper functions. +Miscellaneous helper functions. None of these should be considered part of the public API. """ diff --git a/src/asyncio_taskpool/queue_context.py b/src/asyncio_taskpool/queue_context.py index 29959bf..9870b92 100644 --- a/src/asyncio_taskpool/queue_context.py +++ b/src/asyncio_taskpool/queue_context.py @@ -53,6 +53,6 @@ class Queue(_Queue): Implements an asynchronous context manager for the queue. Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want, - but in some situations it makes the codes much cleaner. + but in some situations it makes the code much cleaner. """ self.item_processed() diff --git a/tests/test_control/test_client.py b/tests/test_control/test_client.py index 803194f..bcc844c 100644 --- a/tests/test_control/test_client.py +++ b/tests/test_control/test_client.py @@ -55,7 +55,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase): def test_client_info(self): self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}, - client.ControlClient.client_info()) + client.ControlClient._client_info()) async def test_abstract(self): with self.assertRaises(NotImplementedError): @@ -65,12 +65,12 @@ class ControlClientTestCase(IsolatedAsyncioTestCase): self.assertEqual(self.kwargs, self.client._conn_kwargs) self.assertFalse(self.client._connected) - @patch.object(client.ControlClient, 'client_info') - async def test__server_handshake(self, mock_client_info: MagicMock): - mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999} + @patch.object(client.ControlClient, '_client_info') + async def test__server_handshake(self, mock__client_info: MagicMock): + mock__client_info.return_value = mock_info = {FOO: 1, BAR: 9999} self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer)) self.assertTrue(self.client._connected) - mock_client_info.assert_called_once_with() + mock__client_info.assert_called_once_with() self.mock_write.assert_called_once_with(json.dumps(mock_info).encode()) self.mock_drain.assert_awaited_once_with() self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) diff --git a/tests/test_control/test_parser.py b/tests/test_control/test_parser.py index 2cd0c57..e7008eb 100644 --- a/tests/test_control/test_parser.py +++ b/tests/test_control/test_parser.py @@ -25,7 +25,7 @@ from unittest import TestCase from unittest.mock import MagicMock, call, patch from asyncio_taskpool.control import parser -from asyncio_taskpool.exceptions import HelpRequested +from asyncio_taskpool.exceptions import HelpRequested, ParserError FOO, BAR = 'foo', 'bar' @@ -41,7 +41,7 @@ class ControlServerTestCase(TestCase): self.kwargs = { 'stream_writer': self.stream_writer, 'terminal_width': self.terminal_width, - parser.FORMATTER_CLASS: FOO + 'formatter_class': FOO } self.parser = parser.ControlParser(**self.kwargs) @@ -183,7 +183,7 @@ class ControlServerTestCase(TestCase): @patch.object(parser.ArgumentParser, 'error') def test_error(self, mock_supercls_error: MagicMock): - with self.assertRaises(HelpRequested): + with self.assertRaises(ParserError): self.parser.error(FOO + BAR) mock_supercls_error.assert_called_once_with(message=FOO + BAR) @@ -194,11 +194,11 @@ class ControlServerTestCase(TestCase): self.parser.print_help(arg) mock_print_help.assert_called_once_with(arg) - @patch.object(parser, 'get_arg_type_wrapper') + @patch.object(parser, '_get_arg_type_wrapper') @patch.object(parser.ArgumentParser, 'add_argument') - def test_add_function_arg(self, mock_add_argument: MagicMock, mock_get_arg_type_wrapper: MagicMock): + def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_arg_type_wrapper: MagicMock): mock_add_argument.return_value = expected_output = 'action' - mock_get_arg_type_wrapper.return_value = mock_type = 'fake' + mock__get_arg_type_wrapper.return_value = mock_type = 'fake' foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex bar_default, baz_default, boo_default = 1, 0.1, 1j @@ -211,42 +211,42 @@ class ControlServerTestCase(TestCase): kwargs = {FOO + BAR: 'xyz'} self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs)) mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs) - mock_get_arg_type_wrapper.assert_called_once_with(foo_type) + mock__get_arg_type_wrapper.assert_called_once_with(foo_type) mock_add_argument.reset_mock() - mock_get_arg_type_wrapper.reset_mock() + mock__get_arg_type_wrapper.reset_mock() self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs)) mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs) - mock_get_arg_type_wrapper.assert_called_once_with(args_type) + mock__get_arg_type_wrapper.assert_called_once_with(args_type) mock_add_argument.reset_mock() - mock_get_arg_type_wrapper.reset_mock() + mock__get_arg_type_wrapper.reset_mock() self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs)) mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs) - mock_get_arg_type_wrapper.assert_called_once_with(bar_type) + mock__get_arg_type_wrapper.assert_called_once_with(bar_type) mock_add_argument.reset_mock() - mock_get_arg_type_wrapper.reset_mock() + mock__get_arg_type_wrapper.reset_mock() self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs)) mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs) - mock_get_arg_type_wrapper.assert_called_once_with(baz_type) + mock__get_arg_type_wrapper.assert_called_once_with(baz_type) mock_add_argument.reset_mock() - mock_get_arg_type_wrapper.reset_mock() + mock__get_arg_type_wrapper.reset_mock() self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs)) mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs) - mock_get_arg_type_wrapper.assert_called_once_with(boo_type) + mock__get_arg_type_wrapper.assert_called_once_with(boo_type) mock_add_argument.reset_mock() - mock_get_arg_type_wrapper.reset_mock() + mock__get_arg_type_wrapper.reset_mock() self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs)) mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs) - mock_get_arg_type_wrapper.assert_not_called() + mock__get_arg_type_wrapper.assert_not_called() @patch.object(parser.ControlParser, 'add_function_arg') def test_add_function_args(self, mock_add_function_arg: MagicMock): @@ -261,7 +261,8 @@ class ControlServerTestCase(TestCase): class RestTestCase(TestCase): - def test_get_arg_type_wrapper(self): - type_wrap = parser.get_arg_type_wrapper(int) + def test__get_arg_type_wrapper(self): + type_wrap = parser._get_arg_type_wrapper(int) + self.assertEqual('int', type_wrap.__name__) self.assertEqual(SUPPRESS, type_wrap(SUPPRESS)) self.assertEqual(13, type_wrap('13'))