From be03097bf4ddef206cf0b2d703d46ca128dc246a Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sat, 12 Feb 2022 22:51:52 +0100 Subject: [PATCH] massive overhaul of the control server interface; making use of `ArgumentParser` for client commands; new `ControlSession` object instantiated upon connection --- setup.cfg | 2 +- src/asyncio_taskpool/client.py | 15 +- src/asyncio_taskpool/constants.py | 11 +- src/asyncio_taskpool/exceptions.py | 8 ++ src/asyncio_taskpool/helpers.py | 22 ++- src/asyncio_taskpool/pool.py | 1 + src/asyncio_taskpool/server.py | 91 +++--------- src/asyncio_taskpool/session.py | 218 +++++++++++++++++++++++++++++ src/asyncio_taskpool/types.py | 2 +- tests/test_helpers.py | 8 ++ 10 files changed, 300 insertions(+), 78 deletions(-) create mode 100644 src/asyncio_taskpool/session.py diff --git a/setup.cfg b/setup.cfg index 965f410..1bdeaca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.2.1 +version = 0.3.0 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/client.py b/src/asyncio_taskpool/client.py index 5c71f59..bab6b59 100644 --- a/src/asyncio_taskpool/client.py +++ b/src/asyncio_taskpool/client.py @@ -19,6 +19,8 @@ Classes of control clients for a simply interface to a task pool control server. """ +import json +import shutil import sys from abc import ABC, abstractmethod from asyncio.streams import StreamReader, StreamWriter, open_unix_connection @@ -34,10 +36,20 @@ class ControlClient(ABC): async def open_connection(self, **kwargs) -> ClientConnT: raise NotImplementedError + @staticmethod + def client_info() -> dict: + return {'width': shutil.get_terminal_size().columns} + def __init__(self, **conn_kwargs) -> None: self._conn_kwargs = conn_kwargs self._connected: bool = False + async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None: + self._connected = True + writer.write(json.dumps(self.client_info()).encode()) + await writer.drain() + print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) + async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: try: msg = input("> ").strip().lower() @@ -64,8 +76,7 @@ class ControlClient(ABC): if reader is None: print("Failed to connect.", file=sys.stderr) return - self._connected = True - print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) + await self._server_handshake(reader, writer) while self._connected: await self._interact(reader, writer) print("Disconnected from control server.") diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index 4812cda..785384c 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -20,10 +20,13 @@ Constants used by more than one module in the package. PACKAGE_NAME = 'asyncio_taskpool' -MSG_BYTES = 1024 +MSG_BYTES = 1024000 +CMD = 'command' +CMD_NAME = 'name' +CMD_POOL_SIZE = 'pool-size' +CMD_NUM_RUNNING = 'num-running' CMD_START = 'start' CMD_STOP = 'stop' -CMD_STOP_ALL = 'stop_all' -CMD_NUM_RUNNING = 'num_running' -CMD_FUNC = 'func' +CMD_STOP_ALL = 'stop-all' +CMD_FUNC_NAME = 'func-name' CLIENT_EXIT = 'exit' diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py index f9772d9..e6e31cc 100644 --- a/src/asyncio_taskpool/exceptions.py +++ b/src/asyncio_taskpool/exceptions.py @@ -49,3 +49,11 @@ class PoolStillUnlocked(PoolException): class NotCoroutine(PoolException): pass + + +class ServerException(Exception): + pass + + +class HelpRequested(ServerException): + pass diff --git a/src/asyncio_taskpool/helpers.py b/src/asyncio_taskpool/helpers.py index 5bc4168..4bc2ffb 100644 --- a/src/asyncio_taskpool/helpers.py +++ b/src/asyncio_taskpool/helpers.py @@ -19,9 +19,11 @@ Miscellaneous helper functions. """ +import re from asyncio.coroutines import iscoroutinefunction from asyncio.queues import Queue -from typing import Any, Optional +from inspect import getdoc +from typing import Any, Optional, Union from .types import T, AnyCallableT, ArgsT, KwArgsT @@ -48,3 +50,21 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T: async def join_queue(q: Queue) -> None: await q.join() + + +def tasks_str(num: int) -> str: + return "tasks" if num != 1 else "task" + + +def get_first_doc_line(obj: object) -> str: + return getdoc(obj).strip().split("\n", 1)[0] + + +async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]: + try: + if iscoroutinefunction(_function_to_execute): + return await _function_to_execute(*args, **kwargs) + else: + return _function_to_execute(*args, **kwargs) + except Exception as e: + return e diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 9dcf91c..81d579e 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -78,6 +78,7 @@ class BaseTaskPool: log.debug("%s initialized", str(self)) def __str__(self) -> str: + """Returns the name of the task pool.""" return f'{self.__class__.__name__}-{self._name or self._idx}' @property diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index 77780c2..3a3164e 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -26,32 +26,23 @@ from asyncio.exceptions import CancelledError from asyncio.streams import StreamReader, StreamWriter, start_unix_server from asyncio.tasks import Task, create_task from pathlib import Path -from typing import Tuple, Union, Optional +from typing import Optional, Union -from . import constants -from .pool import SimpleTaskPool from .client import ControlClient, UnixControlClient +from .pool import TaskPool, SimpleTaskPool +from .session import ControlSession log = logging.getLogger(__name__) -def tasks_str(num: int) -> str: - return "tasks" if num != 1 else "task" - - -def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]: - cmd = msg.strip().split(' ', 1) - if len(cmd) > 1: - try: - return cmd[0], int(cmd[1]) - except ValueError: - return None, None - return cmd[0], None - - class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool - client_class = ControlClient + _client_class = ControlClient + + @classmethod + @property + def client_class_name(cls) -> str: + return cls._client_class.__name__ @abstractmethod async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: @@ -61,63 +52,25 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta def final_callback(self) -> None: raise NotImplementedError - def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: - self._pool: SimpleTaskPool = pool + def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None: + self._pool: Union[TaskPool, SimpleTaskPool] = pool self._server_kwargs = server_kwargs self._server: Optional[AbstractServer] = None - async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None: - if num is None: - num = 1 - log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num)) - writer.write(str(await self._pool.start(num)).encode()) + def __str__(self) -> str: + return f"{self.__class__.__name__} for {self._pool}" - def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: - if num is None: - num = 1 - log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num)) - # the requested number may be greater than the total number of running tasks - writer.write(str(self._pool.stop(num)).encode()) + @property + def pool(self) -> Union[TaskPool, SimpleTaskPool]: + return self._pool - def _stop_all_tasks(self, writer: StreamWriter) -> None: - log.debug("%s requests stopping all tasks", self.client_class.__name__) - writer.write(str(self._pool.stop_all()).encode()) - - def _pool_size(self, writer: StreamWriter) -> None: - log.debug("%s requests number of running tasks", self.client_class.__name__) - writer.write(str(self._pool.num_running).encode()) - - def _pool_func(self, writer: StreamWriter) -> None: - log.debug("%s requests pool function", self.client_class.__name__) - writer.write(self._pool.func_name.encode()) - - async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None: - while self._server.is_serving(): - msg = (await reader.read(constants.MSG_BYTES)).decode().strip() - if not msg: - log.debug("%s disconnected", self.client_class.__name__) - break - cmd, arg = get_cmd_arg(msg) - if cmd == constants.CMD_START: - await self._start_tasks(writer, arg) - elif cmd == constants.CMD_STOP: - self._stop_tasks(writer, arg) - elif cmd == constants.CMD_STOP_ALL: - self._stop_all_tasks(writer) - elif cmd == constants.CMD_NUM_RUNNING: - self._pool_size(writer) - elif cmd == constants.CMD_FUNC: - self._pool_func(writer) - else: - log.debug("%s sent invalid command: %s", self.client_class.__name__, msg) - writer.write(b"Invalid command!") - await writer.drain() + def is_serving(self) -> bool: + return self._server.is_serving() async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: - log.debug("%s connected", self.client_class.__name__) - writer.write(str(self._pool).encode()) - await writer.drain() - await self._listen(reader, writer) + session = ControlSession(self, reader, writer) + await session.client_handshake() + await session.listen() async def _serve_forever(self) -> None: try: @@ -135,7 +88,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta class UnixControlServer(ControlServer): - client_class = UnixControlClient + _client_class = UnixControlClient def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: self._socket_path = Path(server_kwargs.pop('path')) diff --git a/src/asyncio_taskpool/session.py b/src/asyncio_taskpool/session.py new file mode 100644 index 0000000..f13ff3c --- /dev/null +++ b/src/asyncio_taskpool/session.py @@ -0,0 +1,218 @@ +import logging +import json +from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace +from asyncio.streams import StreamReader, StreamWriter +from typing import Callable, Optional, Type, Union, TYPE_CHECKING + +from . import constants +from .exceptions import HelpRequested +from .helpers import get_first_doc_line, return_or_exception, tasks_str +from .pool import TaskPool, SimpleTaskPool + +if TYPE_CHECKING: + from .server import ControlServer + + +log = logging.getLogger(__name__) + +NUM = 'num' +WIDTH = 'width' + + +class CommandParser(ArgumentParser): + @staticmethod + def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]: + class ClientHelpFormatter(HelpFormatter): + def __init__(self, *args, **kwargs) -> None: + kwargs[WIDTH] = terminal_width + super().__init__(*args, **kwargs) + return ClientHelpFormatter + + def __init__(self, *args, **kwargs) -> None: + parent: CommandParser = kwargs.pop('parent', None) + self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer') + self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH) + kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width)) + kwargs.setdefault('exit_on_error', False) + super().__init__(*args, **kwargs) + + @property + def stream_writer(self) -> StreamWriter: + return self._stream_writer + + @property + def terminal_width(self) -> int: + return self._terminal_width + + def _print_message(self, message: str, *args, **kwargs) -> None: + if message: + self.stream_writer.write(message.encode()) + + def exit(self, status: int = 0, message: str = None) -> None: + if message: + self._print_message(message) + + def print_help(self, file=None) -> None: + super().print_help(file) + raise HelpRequested + + +class ControlSession: + def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: + self._control_server: 'ControlServer' = server + self._pool: Union[TaskPool, SimpleTaskPool] = server.pool + self._client_class_name = server.client_class_name + self._reader: StreamReader = reader + self._writer: StreamWriter = writer + self._parser: Optional[CommandParser] = None + + def _add_base_commands(self): + subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD) + subparsers.add_parser( + constants.CMD_NAME, + prog=constants.CMD_NAME, + help=get_first_doc_line(self._pool.__class__.__str__), + parent=self._parser, + ) + subparser_pool_size = subparsers.add_parser( + constants.CMD_POOL_SIZE, + prog=constants.CMD_POOL_SIZE, + help="Get/set the maximum number of tasks in the pool", + parent=self._parser, + ) + subparser_pool_size.add_argument( + NUM, + nargs='?', + help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " + f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" + ) + subparsers.add_parser( + constants.CMD_NUM_RUNNING, + help=get_first_doc_line(self._pool.__class__.num_running.fget), + parent=self._parser, + ) + return subparsers + + def _add_simple_commands(self): + subparsers = self._add_base_commands() + subparser = subparsers.add_parser( + constants.CMD_START, + prog=constants.CMD_START, + help=get_first_doc_line(self._pool.__class__.start), + parent=self._parser, + ) + subparser.add_argument( + NUM, + nargs='?', + type=int, + default=1, + help="Number of tasks to start. Defaults to 1." + ) + subparser = subparsers.add_parser( + constants.CMD_STOP, + prog=constants.CMD_STOP, + help=get_first_doc_line(self._pool.__class__.stop), + parent=self._parser, + ) + subparser.add_argument( + NUM, + nargs='?', + type=int, + default=1, + help="Number of tasks to stop. Defaults to 1." + ) + subparsers.add_parser( + constants.CMD_STOP_ALL, + prog=constants.CMD_STOP_ALL, + help=get_first_doc_line(self._pool.__class__.stop_all), + parent=self._parser, + ) + subparsers.add_parser( + constants.CMD_FUNC_NAME, + prog=constants.CMD_FUNC_NAME, + help=get_first_doc_line(self._pool.__class__.func_name.fget), + parent=self._parser, + ) + + def _init_parser(self, client_terminal_width: int) -> None: + self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width) + if isinstance(self._pool, TaskPool): + pass # TODO + elif isinstance(self._pool, SimpleTaskPool): + self._add_simple_commands() + + async def client_handshake(self) -> None: + client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip()) + log.debug("%s connected", self._client_class_name) + self._init_parser(client_info[WIDTH]) + self._writer.write(str(self._pool).encode()) + await self._writer.drain() + + async def _write_function_output(self, func: Callable, *args, **kwargs) -> None: + output = await return_or_exception(func, *args, **kwargs) + self._writer.write(b"ok" if output is None else str(output).encode()) + + async def _cmd_name(self, **_kwargs) -> None: + log.debug("%s requests task pool name", self._client_class_name) + await self._write_function_output(self._pool.__class__.__str__, self._pool) + + async def _cmd_pool_size(self, **kwargs) -> None: + num = kwargs.get(NUM) + if num is None: + log.debug("%s requests pool size", self._client_class_name) + await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool) + else: + log.debug("%s requests setting pool size to %s", self._client_class_name, num) + await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num)) + + async def _cmd_num_running(self, **_kwargs) -> None: + log.debug("%s requests number of running tasks", self._client_class_name) + await self._write_function_output(self._pool.__class__.num_running.fget, self._pool) + + async def _cmd_start(self, **kwargs) -> None: + num = kwargs[NUM] + log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num)) + await self._write_function_output(self._pool.start, num) + + async def _cmd_stop(self, **kwargs) -> None: + num = kwargs[NUM] + log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num)) + await self._write_function_output(self._pool.stop, num) + + async def _cmd_stop_all(self, **_kwargs) -> None: + log.debug("%s requests stopping all tasks", self._client_class_name) + await self._write_function_output(self._pool.stop_all) + + async def _cmd_func_name(self, **_kwargs) -> None: + log.debug("%s requests pool function name", self._client_class_name) + await self._write_function_output(self._pool.__class__.func_name.fget, self._pool) + + async def _execute_command(self, args: Namespace) -> None: + args = vars(args) + cmd: str = args.pop(constants.CMD, None) + if cmd is not None: + method = getattr(self, f'_cmd_{cmd.replace("-", "_")}') + await method(**args) + + async def _parse_command(self, msg: str) -> None: + try: + args, argv = self._parser.parse_known_args(msg.split(' ')) + except ArgumentError as e: + self._writer.write(str(e).encode()) + return + except HelpRequested: + return + if argv: + log.debug("%s sent unknown arguments: %s", self._client_class_name, msg) + self._writer.write(b"Invalid command!") + return + await self._execute_command(args) + + async def listen(self) -> None: + while self._control_server.is_serving(): + msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip() + if not msg: + log.debug("%s disconnected", self._client_class_name) + break + await self._parse_command(msg) + await self._writer.drain() diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index ac4a871..bca1c33 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -28,7 +28,7 @@ T = TypeVar('T') ArgsT = Iterable[Any] KwArgsT = Mapping[str, Any] -AnyCallableT = Callable[[...], Union[Awaitable[T], T]] +AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] CoroutineFunc = Callable[[...], Awaitable[Any]] EndCallbackT = Callable diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5dbf657..9c3e551 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -86,3 +86,11 @@ class HelpersTestCase(IsolatedAsyncioTestCase): mock_queue = MagicMock(join=mock_join) self.assertIsNone(await helpers.join_queue(mock_queue)) mock_join.assert_awaited_once_with() + + def test_task_str(self): + self.assertEqual("task", helpers.tasks_str(1)) + self.assertEqual("tasks", helpers.tasks_str(0)) + self.assertEqual("tasks", helpers.tasks_str(-1)) + self.assertEqual("tasks", helpers.tasks_str(2)) + self.assertEqual("tasks", helpers.tasks_str(-10)) + self.assertEqual("tasks", helpers.tasks_str(42))