major refactoring of the control session/parser classes; restructured constants

This commit is contained in:
Daniil Fajnberg 2022-02-13 19:39:21 +01:00
parent 3fb451a00e
commit 538b9cc91c
5 changed files with 137 additions and 126 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.3.1 version = 0.3.2
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -26,8 +26,8 @@ from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path from pathlib import Path
from asyncio_taskpool import constants from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from asyncio_taskpool.types import ClientConnT from .types import ClientConnT
class ControlClient(ABC): class ControlClient(ABC):
@ -38,7 +38,7 @@ class ControlClient(ABC):
@staticmethod @staticmethod
def client_info() -> dict: def client_info() -> dict:
return {'width': shutil.get_terminal_size().columns} return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
def __init__(self, **conn_kwargs) -> None: def __init__(self, **conn_kwargs) -> None:
self._conn_kwargs = conn_kwargs self._conn_kwargs = conn_kwargs
@ -48,17 +48,17 @@ class ControlClient(ABC):
self._connected = True self._connected = True
writer.write(json.dumps(self.client_info()).encode()) writer.write(json.dumps(self.client_info()).encode())
await writer.drain() await writer.drain()
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
try: try:
msg = input("> ").strip().lower() msg = input("> ").strip().lower()
except EOFError: except EOFError:
msg = constants.CLIENT_EXIT msg = CLIENT_EXIT
except KeyboardInterrupt: except KeyboardInterrupt:
print() print()
return return
if msg == constants.CLIENT_EXIT: if msg == CLIENT_EXIT:
writer.close() writer.close()
self._connected = False self._connected = False
return return
@ -69,7 +69,7 @@ class ControlClient(ABC):
self._connected = False self._connected = False
print(e, file=sys.stderr) print(e, file=sys.stderr)
return return
print((await reader.read(constants.MSG_BYTES)).decode()) print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self): async def start(self):
reader, writer = await self.open_connection(**self._conn_kwargs) reader, writer = await self.open_connection(**self._conn_kwargs)

View File

@ -20,13 +20,25 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024000
CMD = 'command'
CMD_NAME = 'name'
CMD_POOL_SIZE = 'pool-size'
CMD_NUM_RUNNING = 'num-running'
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop-all'
CMD_FUNC_NAME = 'func-name'
CLIENT_EXIT = 'exit' CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
SESSION_PARSER_WRITER = 'session_writer'
class CLIENT_INFO:
__slots__ = ()
TERMINAL_WIDTH = 'terminal_width'
class CMD:
__slots__ = ()
CMD = 'command'
NAME = 'name'
POOL_SIZE = 'pool-size'
NUM_RUNNING = 'num-running'
START = 'start'
STOP = 'stop'
STOP_ALL = 'stop-all'
FUNC_NAME = 'func-name'

View File

@ -1,13 +1,14 @@
import logging import logging
import json import json
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace from argparse import ArgumentError, HelpFormatter, Namespace
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Type, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from . import constants from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested from .exceptions import HelpRequested
from .helpers import get_first_doc_line, return_or_exception, tasks_str from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import TaskPool, SimpleTaskPool from .pool import TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM
if TYPE_CHECKING: if TYPE_CHECKING:
from .server import ControlServer from .server import ControlServer
@ -15,47 +16,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
NUM = 'num'
WIDTH = 'width'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
class ClientHelpFormatter(HelpFormatter):
def __init__(self, *args, **kwargs) -> None:
kwargs[WIDTH] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
class ControlSession: class ControlSession:
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
@ -65,86 +25,65 @@ class ControlSession:
self._reader: StreamReader = reader self._reader: StreamReader = reader
self._writer: StreamWriter = writer self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None self._parser: Optional[CommandParser] = None
self._subparsers = None
def _add_base_commands(self): def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD) **kwargs) -> CommandParser:
subparsers.add_parser( if prog is None:
constants.CMD_NAME, prog = name
prog=constants.CMD_NAME, kwargs.setdefault('help', short_help or long_help)
help=get_first_doc_line(self._pool.__class__.__str__), kwargs.setdefault('description', long_help or short_help)
parent=self._parser, return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
)
subparser_pool_size = subparsers.add_parser( def _add_base_commands(self) -> None:
constants.CMD_POOL_SIZE, self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
prog=constants.CMD_POOL_SIZE, self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
help="Get/set the maximum number of tasks in the pool", self._add_parser_command(
parent=self._parser, CMD.POOL_SIZE,
) short_help="Get/set the maximum number of tasks in the pool.",
subparser_pool_size.add_argument( formatter_class=HelpFormatter
NUM, ).add_optional_num_argument(
nargs='?', default=None,
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
) )
subparsers.add_parser( self._add_parser_command(
constants.CMD_NUM_RUNNING, CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)
help=get_first_doc_line(self._pool.__class__.num_running.fget),
parent=self._parser,
) )
return subparsers
def _add_simple_commands(self): def _add_simple_commands(self) -> None:
subparsers = self._add_base_commands() self._add_parser_command(
subparser = subparsers.add_parser( CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
constants.CMD_START, ).add_optional_num_argument(
prog=constants.CMD_START, help="Number of tasks to start."
help=get_first_doc_line(self._pool.__class__.start),
parent=self._parser,
) )
subparser.add_argument( self._add_parser_command(
NUM, CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
nargs='?', ).add_optional_num_argument(
type=int, help="Number of tasks to stop."
default=1,
help="Number of tasks to start. Defaults to 1."
) )
subparser = subparsers.add_parser( self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
constants.CMD_STOP, self._add_parser_command(
prog=constants.CMD_STOP, CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)
help=get_first_doc_line(self._pool.__class__.stop),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to stop. Defaults to 1."
)
subparsers.add_parser(
constants.CMD_STOP_ALL,
prog=constants.CMD_STOP_ALL,
help=get_first_doc_line(self._pool.__class__.stop_all),
parent=self._parser,
)
subparsers.add_parser(
constants.CMD_FUNC_NAME,
prog=constants.CMD_FUNC_NAME,
help=get_first_doc_line(self._pool.__class__.func_name.fget),
parent=self._parser,
) )
def _init_parser(self, client_terminal_width: int) -> None: def _init_parser(self, client_terminal_width: int) -> None:
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width) parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
}
self._parser = CommandParser(**parser_kwargs)
self._add_base_commands()
if isinstance(self._pool, TaskPool): if isinstance(self._pool, TaskPool):
pass # TODO pass # TODO
elif isinstance(self._pool, SimpleTaskPool): elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands() self._add_simple_commands()
async def client_handshake(self) -> None: async def client_handshake(self) -> None:
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip()) client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[WIDTH]) self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
self._writer.write(str(self._pool).encode()) self._writer.write(str(self._pool).encode())
await self._writer.drain() await self._writer.drain()
@ -163,7 +102,7 @@ class ControlSession:
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool) await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else: else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num) log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num)) await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_num_running(self, **_kwargs) -> None: async def _cmd_num_running(self, **_kwargs) -> None:
log.debug("%s requests number of running tasks", self._client_class_name) log.debug("%s requests number of running tasks", self._client_class_name)
@ -189,7 +128,7 @@ class ControlSession:
async def _execute_command(self, args: Namespace) -> None: async def _execute_command(self, args: Namespace) -> None:
args = vars(args) args = vars(args)
cmd: str = args.pop(constants.CMD, None) cmd: str = args.pop(CMD.CMD, None)
if cmd is not None: if cmd is not None:
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}') method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
await method(**args) await method(**args)
@ -210,7 +149,7 @@ class ControlSession:
async def listen(self) -> None: async def listen(self) -> None:
while self._control_server.is_serving(): while self._control_server.is_serving():
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip() msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg: if not msg:
log.debug("%s disconnected", self._client_class_name) log.debug("%s disconnected", self._client_class_name)
break break

View File

@ -0,0 +1,60 @@
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter
from typing import Type, TypeVar
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO
from .exceptions import HelpRequested
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
FORMATTER_CLASS = 'formatter_class'
NUM = 'num'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER)
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
if not name_or_flags:
name_or_flags = (NUM, )
kwargs.setdefault('nargs', '?')
kwargs.setdefault('default', 1)
kwargs.setdefault('type', int)
return self.add_argument(*name_or_flags, **kwargs)