generated from daniil-berg/boilerplate-py
massive overhaul of the control server interface;
making use of `ArgumentParser` for client commands; new `ControlSession` object instantiated upon connection
This commit is contained in:
parent
024e5db0d4
commit
be03097bf4
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 0.2.1
|
version = 0.3.0
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -19,6 +19,8 @@ Classes of control clients for a simply interface to a task pool control server.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||||
@ -34,10 +36,20 @@ class ControlClient(ABC):
|
|||||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def client_info() -> dict:
|
||||||
|
return {'width': shutil.get_terminal_size().columns}
|
||||||
|
|
||||||
def __init__(self, **conn_kwargs) -> None:
|
def __init__(self, **conn_kwargs) -> None:
|
||||||
self._conn_kwargs = conn_kwargs
|
self._conn_kwargs = conn_kwargs
|
||||||
self._connected: bool = False
|
self._connected: bool = False
|
||||||
|
|
||||||
|
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
self._connected = True
|
||||||
|
writer.write(json.dumps(self.client_info()).encode())
|
||||||
|
await writer.drain()
|
||||||
|
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
|
||||||
|
|
||||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
try:
|
try:
|
||||||
msg = input("> ").strip().lower()
|
msg = input("> ").strip().lower()
|
||||||
@ -64,8 +76,7 @@ class ControlClient(ABC):
|
|||||||
if reader is None:
|
if reader is None:
|
||||||
print("Failed to connect.", file=sys.stderr)
|
print("Failed to connect.", file=sys.stderr)
|
||||||
return
|
return
|
||||||
self._connected = True
|
await self._server_handshake(reader, writer)
|
||||||
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
|
|
||||||
while self._connected:
|
while self._connected:
|
||||||
await self._interact(reader, writer)
|
await self._interact(reader, writer)
|
||||||
print("Disconnected from control server.")
|
print("Disconnected from control server.")
|
||||||
|
@ -20,10 +20,13 @@ Constants used by more than one module in the package.
|
|||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
MSG_BYTES = 1024
|
MSG_BYTES = 1024000
|
||||||
|
CMD = 'command'
|
||||||
|
CMD_NAME = 'name'
|
||||||
|
CMD_POOL_SIZE = 'pool-size'
|
||||||
|
CMD_NUM_RUNNING = 'num-running'
|
||||||
CMD_START = 'start'
|
CMD_START = 'start'
|
||||||
CMD_STOP = 'stop'
|
CMD_STOP = 'stop'
|
||||||
CMD_STOP_ALL = 'stop_all'
|
CMD_STOP_ALL = 'stop-all'
|
||||||
CMD_NUM_RUNNING = 'num_running'
|
CMD_FUNC_NAME = 'func-name'
|
||||||
CMD_FUNC = 'func'
|
|
||||||
CLIENT_EXIT = 'exit'
|
CLIENT_EXIT = 'exit'
|
||||||
|
@ -49,3 +49,11 @@ class PoolStillUnlocked(PoolException):
|
|||||||
|
|
||||||
class NotCoroutine(PoolException):
|
class NotCoroutine(PoolException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServerException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HelpRequested(ServerException):
|
||||||
|
pass
|
||||||
|
@ -19,9 +19,11 @@ Miscellaneous helper functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
from asyncio.coroutines import iscoroutinefunction
|
from asyncio.coroutines import iscoroutinefunction
|
||||||
from asyncio.queues import Queue
|
from asyncio.queues import Queue
|
||||||
from typing import Any, Optional
|
from inspect import getdoc
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
||||||
|
|
||||||
@ -48,3 +50,21 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
|
|||||||
|
|
||||||
async def join_queue(q: Queue) -> None:
|
async def join_queue(q: Queue) -> None:
|
||||||
await q.join()
|
await q.join()
|
||||||
|
|
||||||
|
|
||||||
|
def tasks_str(num: int) -> str:
|
||||||
|
return "tasks" if num != 1 else "task"
|
||||||
|
|
||||||
|
|
||||||
|
def get_first_doc_line(obj: object) -> str:
|
||||||
|
return getdoc(obj).strip().split("\n", 1)[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
|
||||||
|
try:
|
||||||
|
if iscoroutinefunction(_function_to_execute):
|
||||||
|
return await _function_to_execute(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return _function_to_execute(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
@ -78,6 +78,7 @@ class BaseTaskPool:
|
|||||||
log.debug("%s initialized", str(self))
|
log.debug("%s initialized", str(self))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
"""Returns the name of the task pool."""
|
||||||
return f'{self.__class__.__name__}-{self._name or self._idx}'
|
return f'{self.__class__.__name__}-{self._name or self._idx}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -26,32 +26,23 @@ from asyncio.exceptions import CancelledError
|
|||||||
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
||||||
from asyncio.tasks import Task, create_task
|
from asyncio.tasks import Task, create_task
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union, Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from . import constants
|
|
||||||
from .pool import SimpleTaskPool
|
|
||||||
from .client import ControlClient, UnixControlClient
|
from .client import ControlClient, UnixControlClient
|
||||||
|
from .pool import TaskPool, SimpleTaskPool
|
||||||
|
from .session import ControlSession
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def tasks_str(num: int) -> str:
|
|
||||||
return "tasks" if num != 1 else "task"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
|
|
||||||
cmd = msg.strip().split(' ', 1)
|
|
||||||
if len(cmd) > 1:
|
|
||||||
try:
|
|
||||||
return cmd[0], int(cmd[1])
|
|
||||||
except ValueError:
|
|
||||||
return None, None
|
|
||||||
return cmd[0], None
|
|
||||||
|
|
||||||
|
|
||||||
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
||||||
client_class = ControlClient
|
_client_class = ControlClient
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def client_class_name(cls) -> str:
|
||||||
|
return cls._client_class.__name__
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||||
@ -61,63 +52,25 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
|
|||||||
def final_callback(self) -> None:
|
def final_callback(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
||||||
self._pool: SimpleTaskPool = pool
|
self._pool: Union[TaskPool, SimpleTaskPool] = pool
|
||||||
self._server_kwargs = server_kwargs
|
self._server_kwargs = server_kwargs
|
||||||
self._server: Optional[AbstractServer] = None
|
self._server: Optional[AbstractServer] = None
|
||||||
|
|
||||||
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
def __str__(self) -> str:
|
||||||
if num is None:
|
return f"{self.__class__.__name__} for {self._pool}"
|
||||||
num = 1
|
|
||||||
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
|
|
||||||
writer.write(str(await self._pool.start(num)).encode())
|
|
||||||
|
|
||||||
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
@property
|
||||||
if num is None:
|
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
|
||||||
num = 1
|
return self._pool
|
||||||
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
|
|
||||||
# the requested number may be greater than the total number of running tasks
|
|
||||||
writer.write(str(self._pool.stop(num)).encode())
|
|
||||||
|
|
||||||
def _stop_all_tasks(self, writer: StreamWriter) -> None:
|
def is_serving(self) -> bool:
|
||||||
log.debug("%s requests stopping all tasks", self.client_class.__name__)
|
return self._server.is_serving()
|
||||||
writer.write(str(self._pool.stop_all()).encode())
|
|
||||||
|
|
||||||
def _pool_size(self, writer: StreamWriter) -> None:
|
|
||||||
log.debug("%s requests number of running tasks", self.client_class.__name__)
|
|
||||||
writer.write(str(self._pool.num_running).encode())
|
|
||||||
|
|
||||||
def _pool_func(self, writer: StreamWriter) -> None:
|
|
||||||
log.debug("%s requests pool function", self.client_class.__name__)
|
|
||||||
writer.write(self._pool.func_name.encode())
|
|
||||||
|
|
||||||
async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
|
|
||||||
while self._server.is_serving():
|
|
||||||
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
|
|
||||||
if not msg:
|
|
||||||
log.debug("%s disconnected", self.client_class.__name__)
|
|
||||||
break
|
|
||||||
cmd, arg = get_cmd_arg(msg)
|
|
||||||
if cmd == constants.CMD_START:
|
|
||||||
await self._start_tasks(writer, arg)
|
|
||||||
elif cmd == constants.CMD_STOP:
|
|
||||||
self._stop_tasks(writer, arg)
|
|
||||||
elif cmd == constants.CMD_STOP_ALL:
|
|
||||||
self._stop_all_tasks(writer)
|
|
||||||
elif cmd == constants.CMD_NUM_RUNNING:
|
|
||||||
self._pool_size(writer)
|
|
||||||
elif cmd == constants.CMD_FUNC:
|
|
||||||
self._pool_func(writer)
|
|
||||||
else:
|
|
||||||
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
|
|
||||||
writer.write(b"Invalid command!")
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
log.debug("%s connected", self.client_class.__name__)
|
session = ControlSession(self, reader, writer)
|
||||||
writer.write(str(self._pool).encode())
|
await session.client_handshake()
|
||||||
await writer.drain()
|
await session.listen()
|
||||||
await self._listen(reader, writer)
|
|
||||||
|
|
||||||
async def _serve_forever(self) -> None:
|
async def _serve_forever(self) -> None:
|
||||||
try:
|
try:
|
||||||
@ -135,7 +88,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
|
|||||||
|
|
||||||
|
|
||||||
class UnixControlServer(ControlServer):
|
class UnixControlServer(ControlServer):
|
||||||
client_class = UnixControlClient
|
_client_class = UnixControlClient
|
||||||
|
|
||||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||||
self._socket_path = Path(server_kwargs.pop('path'))
|
self._socket_path = Path(server_kwargs.pop('path'))
|
||||||
|
218
src/asyncio_taskpool/session.py
Normal file
218
src/asyncio_taskpool/session.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace
|
||||||
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
|
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
from . import constants
|
||||||
|
from .exceptions import HelpRequested
|
||||||
|
from .helpers import get_first_doc_line, return_or_exception, tasks_str
|
||||||
|
from .pool import TaskPool, SimpleTaskPool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .server import ControlServer
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NUM = 'num'
|
||||||
|
WIDTH = 'width'
|
||||||
|
|
||||||
|
|
||||||
|
class CommandParser(ArgumentParser):
|
||||||
|
@staticmethod
|
||||||
|
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
|
||||||
|
class ClientHelpFormatter(HelpFormatter):
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
kwargs[WIDTH] = terminal_width
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
return ClientHelpFormatter
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
parent: CommandParser = kwargs.pop('parent', None)
|
||||||
|
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
|
||||||
|
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
|
||||||
|
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
|
||||||
|
kwargs.setdefault('exit_on_error', False)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stream_writer(self) -> StreamWriter:
|
||||||
|
return self._stream_writer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def terminal_width(self) -> int:
|
||||||
|
return self._terminal_width
|
||||||
|
|
||||||
|
def _print_message(self, message: str, *args, **kwargs) -> None:
|
||||||
|
if message:
|
||||||
|
self.stream_writer.write(message.encode())
|
||||||
|
|
||||||
|
def exit(self, status: int = 0, message: str = None) -> None:
|
||||||
|
if message:
|
||||||
|
self._print_message(message)
|
||||||
|
|
||||||
|
def print_help(self, file=None) -> None:
|
||||||
|
super().print_help(file)
|
||||||
|
raise HelpRequested
|
||||||
|
|
||||||
|
|
||||||
|
class ControlSession:
|
||||||
|
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
self._control_server: 'ControlServer' = server
|
||||||
|
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
|
||||||
|
self._client_class_name = server.client_class_name
|
||||||
|
self._reader: StreamReader = reader
|
||||||
|
self._writer: StreamWriter = writer
|
||||||
|
self._parser: Optional[CommandParser] = None
|
||||||
|
|
||||||
|
def _add_base_commands(self):
|
||||||
|
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD)
|
||||||
|
subparsers.add_parser(
|
||||||
|
constants.CMD_NAME,
|
||||||
|
prog=constants.CMD_NAME,
|
||||||
|
help=get_first_doc_line(self._pool.__class__.__str__),
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
subparser_pool_size = subparsers.add_parser(
|
||||||
|
constants.CMD_POOL_SIZE,
|
||||||
|
prog=constants.CMD_POOL_SIZE,
|
||||||
|
help="Get/set the maximum number of tasks in the pool",
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
subparser_pool_size.add_argument(
|
||||||
|
NUM,
|
||||||
|
nargs='?',
|
||||||
|
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
|
||||||
|
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
|
||||||
|
)
|
||||||
|
subparsers.add_parser(
|
||||||
|
constants.CMD_NUM_RUNNING,
|
||||||
|
help=get_first_doc_line(self._pool.__class__.num_running.fget),
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
return subparsers
|
||||||
|
|
||||||
|
def _add_simple_commands(self):
|
||||||
|
subparsers = self._add_base_commands()
|
||||||
|
subparser = subparsers.add_parser(
|
||||||
|
constants.CMD_START,
|
||||||
|
prog=constants.CMD_START,
|
||||||
|
help=get_first_doc_line(self._pool.__class__.start),
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
subparser.add_argument(
|
||||||
|
NUM,
|
||||||
|
nargs='?',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of tasks to start. Defaults to 1."
|
||||||
|
)
|
||||||
|
subparser = subparsers.add_parser(
|
||||||
|
constants.CMD_STOP,
|
||||||
|
prog=constants.CMD_STOP,
|
||||||
|
help=get_first_doc_line(self._pool.__class__.stop),
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
subparser.add_argument(
|
||||||
|
NUM,
|
||||||
|
nargs='?',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of tasks to stop. Defaults to 1."
|
||||||
|
)
|
||||||
|
subparsers.add_parser(
|
||||||
|
constants.CMD_STOP_ALL,
|
||||||
|
prog=constants.CMD_STOP_ALL,
|
||||||
|
help=get_first_doc_line(self._pool.__class__.stop_all),
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
subparsers.add_parser(
|
||||||
|
constants.CMD_FUNC_NAME,
|
||||||
|
prog=constants.CMD_FUNC_NAME,
|
||||||
|
help=get_first_doc_line(self._pool.__class__.func_name.fget),
|
||||||
|
parent=self._parser,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_parser(self, client_terminal_width: int) -> None:
|
||||||
|
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width)
|
||||||
|
if isinstance(self._pool, TaskPool):
|
||||||
|
pass # TODO
|
||||||
|
elif isinstance(self._pool, SimpleTaskPool):
|
||||||
|
self._add_simple_commands()
|
||||||
|
|
||||||
|
async def client_handshake(self) -> None:
|
||||||
|
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip())
|
||||||
|
log.debug("%s connected", self._client_class_name)
|
||||||
|
self._init_parser(client_info[WIDTH])
|
||||||
|
self._writer.write(str(self._pool).encode())
|
||||||
|
await self._writer.drain()
|
||||||
|
|
||||||
|
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
|
||||||
|
output = await return_or_exception(func, *args, **kwargs)
|
||||||
|
self._writer.write(b"ok" if output is None else str(output).encode())
|
||||||
|
|
||||||
|
async def _cmd_name(self, **_kwargs) -> None:
|
||||||
|
log.debug("%s requests task pool name", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.__str__, self._pool)
|
||||||
|
|
||||||
|
async def _cmd_pool_size(self, **kwargs) -> None:
|
||||||
|
num = kwargs.get(NUM)
|
||||||
|
if num is None:
|
||||||
|
log.debug("%s requests pool size", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
|
||||||
|
else:
|
||||||
|
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
|
||||||
|
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num))
|
||||||
|
|
||||||
|
async def _cmd_num_running(self, **_kwargs) -> None:
|
||||||
|
log.debug("%s requests number of running tasks", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
|
||||||
|
|
||||||
|
async def _cmd_start(self, **kwargs) -> None:
|
||||||
|
num = kwargs[NUM]
|
||||||
|
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
|
||||||
|
await self._write_function_output(self._pool.start, num)
|
||||||
|
|
||||||
|
async def _cmd_stop(self, **kwargs) -> None:
|
||||||
|
num = kwargs[NUM]
|
||||||
|
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
|
||||||
|
await self._write_function_output(self._pool.stop, num)
|
||||||
|
|
||||||
|
async def _cmd_stop_all(self, **_kwargs) -> None:
|
||||||
|
log.debug("%s requests stopping all tasks", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.stop_all)
|
||||||
|
|
||||||
|
async def _cmd_func_name(self, **_kwargs) -> None:
|
||||||
|
log.debug("%s requests pool function name", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
|
||||||
|
|
||||||
|
async def _execute_command(self, args: Namespace) -> None:
|
||||||
|
args = vars(args)
|
||||||
|
cmd: str = args.pop(constants.CMD, None)
|
||||||
|
if cmd is not None:
|
||||||
|
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
|
||||||
|
await method(**args)
|
||||||
|
|
||||||
|
async def _parse_command(self, msg: str) -> None:
|
||||||
|
try:
|
||||||
|
args, argv = self._parser.parse_known_args(msg.split(' '))
|
||||||
|
except ArgumentError as e:
|
||||||
|
self._writer.write(str(e).encode())
|
||||||
|
return
|
||||||
|
except HelpRequested:
|
||||||
|
return
|
||||||
|
if argv:
|
||||||
|
log.debug("%s sent unknown arguments: %s", self._client_class_name, msg)
|
||||||
|
self._writer.write(b"Invalid command!")
|
||||||
|
return
|
||||||
|
await self._execute_command(args)
|
||||||
|
|
||||||
|
async def listen(self) -> None:
|
||||||
|
while self._control_server.is_serving():
|
||||||
|
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip()
|
||||||
|
if not msg:
|
||||||
|
log.debug("%s disconnected", self._client_class_name)
|
||||||
|
break
|
||||||
|
await self._parse_command(msg)
|
||||||
|
await self._writer.drain()
|
@ -28,7 +28,7 @@ T = TypeVar('T')
|
|||||||
ArgsT = Iterable[Any]
|
ArgsT = Iterable[Any]
|
||||||
KwArgsT = Mapping[str, Any]
|
KwArgsT = Mapping[str, Any]
|
||||||
|
|
||||||
AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
|
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
|
||||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||||
|
|
||||||
EndCallbackT = Callable
|
EndCallbackT = Callable
|
||||||
|
@ -86,3 +86,11 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_queue = MagicMock(join=mock_join)
|
mock_queue = MagicMock(join=mock_join)
|
||||||
self.assertIsNone(await helpers.join_queue(mock_queue))
|
self.assertIsNone(await helpers.join_queue(mock_queue))
|
||||||
mock_join.assert_awaited_once_with()
|
mock_join.assert_awaited_once_with()
|
||||||
|
|
||||||
|
def test_task_str(self):
|
||||||
|
self.assertEqual("task", helpers.tasks_str(1))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(0))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(-1))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(2))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(-10))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(42))
|
||||||
|
Loading…
Reference in New Issue
Block a user