From 538b9cc91c35a9da8dd88f55ade812cc848d801c Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 13 Feb 2022 19:39:21 +0100 Subject: [PATCH] major refactoring of the control session/parser classes; restructured constants --- setup.cfg | 2 +- src/asyncio_taskpool/client.py | 14 +-- src/asyncio_taskpool/constants.py | 30 +++-- src/asyncio_taskpool/session.py | 157 ++++++++----------------- src/asyncio_taskpool/session_parser.py | 60 ++++++++++ 5 files changed, 137 insertions(+), 126 deletions(-) create mode 100644 src/asyncio_taskpool/session_parser.py diff --git a/setup.cfg b/setup.cfg index 8dae159..5c2e218 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.3.1 +version = 0.3.2 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/client.py b/src/asyncio_taskpool/client.py index bab6b59..0362ce5 100644 --- a/src/asyncio_taskpool/client.py +++ b/src/asyncio_taskpool/client.py @@ -26,8 +26,8 @@ from abc import ABC, abstractmethod from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from pathlib import Path -from asyncio_taskpool import constants -from asyncio_taskpool.types import ClientConnT +from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES +from .types import ClientConnT class ControlClient(ABC): @@ -38,7 +38,7 @@ class ControlClient(ABC): @staticmethod def client_info() -> dict: - return {'width': shutil.get_terminal_size().columns} + return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns} def __init__(self, **conn_kwargs) -> None: self._conn_kwargs = conn_kwargs @@ -48,17 +48,17 @@ class ControlClient(ABC): self._connected = True writer.write(json.dumps(self.client_info()).encode()) await writer.drain() - print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) + print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: try: msg = input("> ").strip().lower() except EOFError: - msg = constants.CLIENT_EXIT + msg = CLIENT_EXIT except KeyboardInterrupt: print() return - if msg == constants.CLIENT_EXIT: + if msg == CLIENT_EXIT: writer.close() self._connected = False return @@ -69,7 +69,7 @@ class ControlClient(ABC): self._connected = False print(e, file=sys.stderr) return - print((await reader.read(constants.MSG_BYTES)).decode()) + print((await reader.read(SESSION_MSG_BYTES)).decode()) async def start(self): reader, writer = await self.open_connection(**self._conn_kwargs) diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index 785384c..19d56c7 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -20,13 +20,25 @@ Constants used by more than one module in the package. PACKAGE_NAME = 'asyncio_taskpool' -MSG_BYTES = 1024000 -CMD = 'command' -CMD_NAME = 'name' -CMD_POOL_SIZE = 'pool-size' -CMD_NUM_RUNNING = 'num-running' -CMD_START = 'start' -CMD_STOP = 'stop' -CMD_STOP_ALL = 'stop-all' -CMD_FUNC_NAME = 'func-name' + CLIENT_EXIT = 'exit' + +SESSION_MSG_BYTES = 1024 * 100 +SESSION_PARSER_WRITER = 'session_writer' + + +class CLIENT_INFO: + __slots__ = () + TERMINAL_WIDTH = 'terminal_width' + + +class CMD: + __slots__ = () + CMD = 'command' + NAME = 'name' + POOL_SIZE = 'pool-size' + NUM_RUNNING = 'num-running' + START = 'start' + STOP = 'stop' + STOP_ALL = 'stop-all' + FUNC_NAME = 'func-name' diff --git a/src/asyncio_taskpool/session.py b/src/asyncio_taskpool/session.py index f13ff3c..b4c069a 100644 --- a/src/asyncio_taskpool/session.py +++ b/src/asyncio_taskpool/session.py @@ -1,13 +1,14 @@ import logging import json -from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace +from argparse import ArgumentError, HelpFormatter, Namespace from asyncio.streams import StreamReader, StreamWriter -from typing import Callable, Optional, Type, Union, TYPE_CHECKING +from typing import Callable, Optional, Union, TYPE_CHECKING -from . import constants +from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO from .exceptions import HelpRequested from .helpers import get_first_doc_line, return_or_exception, tasks_str from .pool import TaskPool, SimpleTaskPool +from .session_parser import CommandParser, NUM if TYPE_CHECKING: from .server import ControlServer @@ -15,47 +16,6 @@ if TYPE_CHECKING: log = logging.getLogger(__name__) -NUM = 'num' -WIDTH = 'width' - - -class CommandParser(ArgumentParser): - @staticmethod - def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]: - class ClientHelpFormatter(HelpFormatter): - def __init__(self, *args, **kwargs) -> None: - kwargs[WIDTH] = terminal_width - super().__init__(*args, **kwargs) - return ClientHelpFormatter - - def __init__(self, *args, **kwargs) -> None: - parent: CommandParser = kwargs.pop('parent', None) - self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer') - self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH) - kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width)) - kwargs.setdefault('exit_on_error', False) - super().__init__(*args, **kwargs) - - @property - def stream_writer(self) -> StreamWriter: - return self._stream_writer - - @property - def terminal_width(self) -> int: - return self._terminal_width - - def _print_message(self, message: str, *args, **kwargs) -> None: - if message: - self.stream_writer.write(message.encode()) - - def exit(self, status: int = 0, message: str = None) -> None: - if message: - self._print_message(message) - - def print_help(self, file=None) -> None: - super().print_help(file) - raise HelpRequested - class ControlSession: def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: @@ -65,86 +25,65 @@ class ControlSession: self._reader: StreamReader = reader self._writer: StreamWriter = writer self._parser: Optional[CommandParser] = None + self._subparsers = None - def _add_base_commands(self): - subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD) - subparsers.add_parser( - constants.CMD_NAME, - prog=constants.CMD_NAME, - help=get_first_doc_line(self._pool.__class__.__str__), - parent=self._parser, - ) - subparser_pool_size = subparsers.add_parser( - constants.CMD_POOL_SIZE, - prog=constants.CMD_POOL_SIZE, - help="Get/set the maximum number of tasks in the pool", - parent=self._parser, - ) - subparser_pool_size.add_argument( - NUM, - nargs='?', + def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None, + **kwargs) -> CommandParser: + if prog is None: + prog = name + kwargs.setdefault('help', short_help or long_help) + kwargs.setdefault('description', long_help or short_help) + return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs) + + def _add_base_commands(self) -> None: + self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD) + self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__)) + self._add_parser_command( + CMD.POOL_SIZE, + short_help="Get/set the maximum number of tasks in the pool.", + formatter_class=HelpFormatter + ).add_optional_num_argument( + default=None, help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" ) - subparsers.add_parser( - constants.CMD_NUM_RUNNING, - help=get_first_doc_line(self._pool.__class__.num_running.fget), - parent=self._parser, + self._add_parser_command( + CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget) ) - return subparsers - def _add_simple_commands(self): - subparsers = self._add_base_commands() - subparser = subparsers.add_parser( - constants.CMD_START, - prog=constants.CMD_START, - help=get_first_doc_line(self._pool.__class__.start), - parent=self._parser, + def _add_simple_commands(self) -> None: + self._add_parser_command( + CMD.START, short_help=get_first_doc_line(self._pool.__class__.start) + ).add_optional_num_argument( + help="Number of tasks to start." ) - subparser.add_argument( - NUM, - nargs='?', - type=int, - default=1, - help="Number of tasks to start. Defaults to 1." + self._add_parser_command( + CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop) + ).add_optional_num_argument( + help="Number of tasks to stop." ) - subparser = subparsers.add_parser( - constants.CMD_STOP, - prog=constants.CMD_STOP, - help=get_first_doc_line(self._pool.__class__.stop), - parent=self._parser, - ) - subparser.add_argument( - NUM, - nargs='?', - type=int, - default=1, - help="Number of tasks to stop. Defaults to 1." - ) - subparsers.add_parser( - constants.CMD_STOP_ALL, - prog=constants.CMD_STOP_ALL, - help=get_first_doc_line(self._pool.__class__.stop_all), - parent=self._parser, - ) - subparsers.add_parser( - constants.CMD_FUNC_NAME, - prog=constants.CMD_FUNC_NAME, - help=get_first_doc_line(self._pool.__class__.func_name.fget), - parent=self._parser, + self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all)) + self._add_parser_command( + CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget) ) def _init_parser(self, client_terminal_width: int) -> None: - self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width) + parser_kwargs = { + 'prog': '', + SESSION_PARSER_WRITER: self._writer, + CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, + } + self._parser = CommandParser(**parser_kwargs) + self._add_base_commands() if isinstance(self._pool, TaskPool): pass # TODO elif isinstance(self._pool, SimpleTaskPool): self._add_simple_commands() async def client_handshake(self) -> None: - client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip()) + client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) log.debug("%s connected", self._client_class_name) - self._init_parser(client_info[WIDTH]) + self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH]) self._writer.write(str(self._pool).encode()) await self._writer.drain() @@ -163,7 +102,7 @@ class ControlSession: await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool) else: log.debug("%s requests setting pool size to %s", self._client_class_name, num) - await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num)) + await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num) async def _cmd_num_running(self, **_kwargs) -> None: log.debug("%s requests number of running tasks", self._client_class_name) @@ -189,7 +128,7 @@ class ControlSession: async def _execute_command(self, args: Namespace) -> None: args = vars(args) - cmd: str = args.pop(constants.CMD, None) + cmd: str = args.pop(CMD.CMD, None) if cmd is not None: method = getattr(self, f'_cmd_{cmd.replace("-", "_")}') await method(**args) @@ -210,7 +149,7 @@ class ControlSession: async def listen(self) -> None: while self._control_server.is_serving(): - msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip() + msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() if not msg: log.debug("%s disconnected", self._client_class_name) break diff --git a/src/asyncio_taskpool/session_parser.py b/src/asyncio_taskpool/session_parser.py new file mode 100644 index 0000000..a315234 --- /dev/null +++ b/src/asyncio_taskpool/session_parser.py @@ -0,0 +1,60 @@ +from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter +from asyncio.streams import StreamWriter +from typing import Type, TypeVar + +from .constants import SESSION_PARSER_WRITER, CLIENT_INFO +from .exceptions import HelpRequested + + +FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter]) +FORMATTER_CLASS = 'formatter_class' +NUM = 'num' + + +class CommandParser(ArgumentParser): + @staticmethod + def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: + if base_cls is None: + base_cls = ArgumentDefaultsHelpFormatter + + class ClientHelpFormatter(base_cls): + def __init__(self, *args, **kwargs) -> None: + kwargs['width'] = terminal_width + super().__init__(*args, **kwargs) + return ClientHelpFormatter + + def __init__(self, *args, **kwargs) -> None: + parent: CommandParser = kwargs.pop('parent', None) + self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER) + self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH) + kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) + kwargs.setdefault('exit_on_error', False) + super().__init__(*args, **kwargs) + + @property + def stream_writer(self) -> StreamWriter: + return self._stream_writer + + @property + def terminal_width(self) -> int: + return self._terminal_width + + def _print_message(self, message: str, *args, **kwargs) -> None: + if message: + self.stream_writer.write(message.encode()) + + def exit(self, status: int = 0, message: str = None) -> None: + if message: + self._print_message(message) + + def print_help(self, file=None) -> None: + super().print_help(file) + raise HelpRequested + + def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action: + if not name_or_flags: + name_or_flags = (NUM, ) + kwargs.setdefault('nargs', '?') + kwargs.setdefault('default', 1) + kwargs.setdefault('type', int) + return self.add_argument(*name_or_flags, **kwargs)