massive overhaul of the control server interface;

making use of `ArgumentParser` for client commands;
new `ControlSession` object instantiated upon connection
This commit is contained in:
Daniil Fajnberg 2022-02-12 22:51:52 +01:00
parent 024e5db0d4
commit be03097bf4
10 changed files with 300 additions and 78 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.2.1 version = 0.3.0
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -19,6 +19,8 @@ Classes of control clients for a simply interface to a task pool control server.
""" """
import json
import shutil
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
@ -34,10 +36,20 @@ class ControlClient(ABC):
async def open_connection(self, **kwargs) -> ClientConnT: async def open_connection(self, **kwargs) -> ClientConnT:
raise NotImplementedError raise NotImplementedError
@staticmethod
def client_info() -> dict:
return {'width': shutil.get_terminal_size().columns}
def __init__(self, **conn_kwargs) -> None: def __init__(self, **conn_kwargs) -> None:
self._conn_kwargs = conn_kwargs self._conn_kwargs = conn_kwargs
self._connected: bool = False self._connected: bool = False
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
self._connected = True
writer.write(json.dumps(self.client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
try: try:
msg = input("> ").strip().lower() msg = input("> ").strip().lower()
@ -64,8 +76,7 @@ class ControlClient(ABC):
if reader is None: if reader is None:
print("Failed to connect.", file=sys.stderr) print("Failed to connect.", file=sys.stderr)
return return
self._connected = True await self._server_handshake(reader, writer)
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
while self._connected: while self._connected:
await self._interact(reader, writer) await self._interact(reader, writer)
print("Disconnected from control server.") print("Disconnected from control server.")

View File

@ -20,10 +20,13 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024 MSG_BYTES = 1024000
CMD = 'command'
CMD_NAME = 'name'
CMD_POOL_SIZE = 'pool-size'
CMD_NUM_RUNNING = 'num-running'
CMD_START = 'start' CMD_START = 'start'
CMD_STOP = 'stop' CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all' CMD_STOP_ALL = 'stop-all'
CMD_NUM_RUNNING = 'num_running' CMD_FUNC_NAME = 'func-name'
CMD_FUNC = 'func'
CLIENT_EXIT = 'exit' CLIENT_EXIT = 'exit'

View File

@ -49,3 +49,11 @@ class PoolStillUnlocked(PoolException):
class NotCoroutine(PoolException): class NotCoroutine(PoolException):
pass pass
class ServerException(Exception):
pass
class HelpRequested(ServerException):
pass

View File

@ -19,9 +19,11 @@ Miscellaneous helper functions.
""" """
import re
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue from asyncio.queues import Queue
from typing import Any, Optional from inspect import getdoc
from typing import Any, Optional, Union
from .types import T, AnyCallableT, ArgsT, KwArgsT from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -48,3 +50,21 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
async def join_queue(q: Queue) -> None: async def join_queue(q: Queue) -> None:
await q.join() await q.join()
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0]
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
try:
if iscoroutinefunction(_function_to_execute):
return await _function_to_execute(*args, **kwargs)
else:
return _function_to_execute(*args, **kwargs)
except Exception as e:
return e

View File

@ -78,6 +78,7 @@ class BaseTaskPool:
log.debug("%s initialized", str(self)) log.debug("%s initialized", str(self))
def __str__(self) -> str: def __str__(self) -> str:
"""Returns the name of the task pool."""
return f'{self.__class__.__name__}-{self._name or self._idx}' return f'{self.__class__.__name__}-{self._name or self._idx}'
@property @property

View File

@ -26,32 +26,23 @@ from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.tasks import Task, create_task from asyncio.tasks import Task, create_task
from pathlib import Path from pathlib import Path
from typing import Tuple, Union, Optional from typing import Optional, Union
from . import constants
from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient from .client import ControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
cmd = msg.strip().split(' ', 1)
if len(cmd) > 1:
try:
return cmd[0], int(cmd[1])
except ValueError:
return None, None
return cmd[0], None
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
client_class = ControlClient _client_class = ControlClient
@classmethod
@property
def client_class_name(cls) -> str:
return cls._client_class.__name__
@abstractmethod @abstractmethod
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
@ -61,63 +52,25 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
def final_callback(self) -> None: def final_callback(self) -> None:
raise NotImplementedError raise NotImplementedError
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
self._pool: SimpleTaskPool = pool self._pool: Union[TaskPool, SimpleTaskPool] = pool
self._server_kwargs = server_kwargs self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None self._server: Optional[AbstractServer] = None
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None: def __str__(self) -> str:
if num is None: return f"{self.__class__.__name__} for {self._pool}"
num = 1
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
writer.write(str(await self._pool.start(num)).encode())
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: @property
if num is None: def pool(self) -> Union[TaskPool, SimpleTaskPool]:
num = 1 return self._pool
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
# the requested number may be greater than the total number of running tasks
writer.write(str(self._pool.stop(num)).encode())
def _stop_all_tasks(self, writer: StreamWriter) -> None: def is_serving(self) -> bool:
log.debug("%s requests stopping all tasks", self.client_class.__name__) return self._server.is_serving()
writer.write(str(self._pool.stop_all()).encode())
def _pool_size(self, writer: StreamWriter) -> None:
log.debug("%s requests number of running tasks", self.client_class.__name__)
writer.write(str(self._pool.num_running).encode())
def _pool_func(self, writer: StreamWriter) -> None:
log.debug("%s requests pool function", self.client_class.__name__)
writer.write(self._pool.func_name.encode())
async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
while self._server.is_serving():
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self.client_class.__name__)
break
cmd, arg = get_cmd_arg(msg)
if cmd == constants.CMD_START:
await self._start_tasks(writer, arg)
elif cmd == constants.CMD_STOP:
self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL:
self._stop_all_tasks(writer)
elif cmd == constants.CMD_NUM_RUNNING:
self._pool_size(writer)
elif cmd == constants.CMD_FUNC:
self._pool_func(writer)
else:
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
writer.write(b"Invalid command!")
await writer.drain()
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
log.debug("%s connected", self.client_class.__name__) session = ControlSession(self, reader, writer)
writer.write(str(self._pool).encode()) await session.client_handshake()
await writer.drain() await session.listen()
await self._listen(reader, writer)
async def _serve_forever(self) -> None: async def _serve_forever(self) -> None:
try: try:
@ -135,7 +88,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
class UnixControlServer(ControlServer): class UnixControlServer(ControlServer):
client_class = UnixControlClient _client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._socket_path = Path(server_kwargs.pop('path')) self._socket_path = Path(server_kwargs.pop('path'))

View File

@ -0,0 +1,218 @@
import logging
import json
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
from . import constants
from .exceptions import HelpRequested
from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import TaskPool, SimpleTaskPool
if TYPE_CHECKING:
from .server import ControlServer
log = logging.getLogger(__name__)
NUM = 'num'
WIDTH = 'width'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
class ClientHelpFormatter(HelpFormatter):
def __init__(self, *args, **kwargs) -> None:
kwargs[WIDTH] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
class ControlSession:
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None
def _add_base_commands(self):
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD)
subparsers.add_parser(
constants.CMD_NAME,
prog=constants.CMD_NAME,
help=get_first_doc_line(self._pool.__class__.__str__),
parent=self._parser,
)
subparser_pool_size = subparsers.add_parser(
constants.CMD_POOL_SIZE,
prog=constants.CMD_POOL_SIZE,
help="Get/set the maximum number of tasks in the pool",
parent=self._parser,
)
subparser_pool_size.add_argument(
NUM,
nargs='?',
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
)
subparsers.add_parser(
constants.CMD_NUM_RUNNING,
help=get_first_doc_line(self._pool.__class__.num_running.fget),
parent=self._parser,
)
return subparsers
def _add_simple_commands(self):
subparsers = self._add_base_commands()
subparser = subparsers.add_parser(
constants.CMD_START,
prog=constants.CMD_START,
help=get_first_doc_line(self._pool.__class__.start),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to start. Defaults to 1."
)
subparser = subparsers.add_parser(
constants.CMD_STOP,
prog=constants.CMD_STOP,
help=get_first_doc_line(self._pool.__class__.stop),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to stop. Defaults to 1."
)
subparsers.add_parser(
constants.CMD_STOP_ALL,
prog=constants.CMD_STOP_ALL,
help=get_first_doc_line(self._pool.__class__.stop_all),
parent=self._parser,
)
subparsers.add_parser(
constants.CMD_FUNC_NAME,
prog=constants.CMD_FUNC_NAME,
help=get_first_doc_line(self._pool.__class__.func_name.fget),
parent=self._parser,
)
def _init_parser(self, client_terminal_width: int) -> None:
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width)
if isinstance(self._pool, TaskPool):
pass # TODO
elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands()
async def client_handshake(self) -> None:
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[WIDTH])
self._writer.write(str(self._pool).encode())
await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None:
log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None:
num = kwargs.get(NUM)
if num is None:
log.debug("%s requests pool size", self._client_class_name)
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num))
async def _cmd_num_running(self, **_kwargs) -> None:
log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_start(self, **kwargs) -> None:
num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None:
num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None:
log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None:
log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
async def _execute_command(self, args: Namespace) -> None:
args = vars(args)
cmd: str = args.pop(constants.CMD, None)
if cmd is not None:
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
await method(**args)
async def _parse_command(self, msg: str) -> None:
try:
args, argv = self._parser.parse_known_args(msg.split(' '))
except ArgumentError as e:
self._writer.write(str(e).encode())
return
except HelpRequested:
return
if argv:
log.debug("%s sent unknown arguments: %s", self._client_class_name, msg)
self._writer.write(b"Invalid command!")
return
await self._execute_command(args)
async def listen(self) -> None:
while self._control_server.is_serving():
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self._client_class_name)
break
await self._parse_command(msg)
await self._writer.drain()

View File

@ -28,7 +28,7 @@ T = TypeVar('T')
ArgsT = Iterable[Any] ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[Awaitable[T], T]] AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable EndCallbackT = Callable

View File

@ -86,3 +86,11 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
mock_queue = MagicMock(join=mock_join) mock_queue = MagicMock(join=mock_join)
self.assertIsNone(await helpers.join_queue(mock_queue)) self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with() mock_join.assert_awaited_once_with()
def test_task_str(self):
self.assertEqual("task", helpers.tasks_str(1))
self.assertEqual("tasks", helpers.tasks_str(0))
self.assertEqual("tasks", helpers.tasks_str(-1))
self.assertEqual("tasks", helpers.tasks_str(2))
self.assertEqual("tasks", helpers.tasks_str(-10))
self.assertEqual("tasks", helpers.tasks_str(42))