generated from daniil-berg/boilerplate-py
Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
538b9cc91c | |||
3fb451a00e |
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 0.3.0
|
version = 0.3.2
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -26,8 +26,8 @@ from abc import ABC, abstractmethod
|
|||||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from asyncio_taskpool import constants
|
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
|
||||||
from asyncio_taskpool.types import ClientConnT
|
from .types import ClientConnT
|
||||||
|
|
||||||
|
|
||||||
class ControlClient(ABC):
|
class ControlClient(ABC):
|
||||||
@ -38,7 +38,7 @@ class ControlClient(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def client_info() -> dict:
|
def client_info() -> dict:
|
||||||
return {'width': shutil.get_terminal_size().columns}
|
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
|
||||||
|
|
||||||
def __init__(self, **conn_kwargs) -> None:
|
def __init__(self, **conn_kwargs) -> None:
|
||||||
self._conn_kwargs = conn_kwargs
|
self._conn_kwargs = conn_kwargs
|
||||||
@ -48,17 +48,17 @@ class ControlClient(ABC):
|
|||||||
self._connected = True
|
self._connected = True
|
||||||
writer.write(json.dumps(self.client_info()).encode())
|
writer.write(json.dumps(self.client_info()).encode())
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
|
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
|
|
||||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
try:
|
try:
|
||||||
msg = input("> ").strip().lower()
|
msg = input("> ").strip().lower()
|
||||||
except EOFError:
|
except EOFError:
|
||||||
msg = constants.CLIENT_EXIT
|
msg = CLIENT_EXIT
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print()
|
print()
|
||||||
return
|
return
|
||||||
if msg == constants.CLIENT_EXIT:
|
if msg == CLIENT_EXIT:
|
||||||
writer.close()
|
writer.close()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
return
|
return
|
||||||
@ -69,7 +69,7 @@ class ControlClient(ABC):
|
|||||||
self._connected = False
|
self._connected = False
|
||||||
print(e, file=sys.stderr)
|
print(e, file=sys.stderr)
|
||||||
return
|
return
|
||||||
print((await reader.read(constants.MSG_BYTES)).decode())
|
print((await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
reader, writer = await self.open_connection(**self._conn_kwargs)
|
reader, writer = await self.open_connection(**self._conn_kwargs)
|
||||||
|
@ -20,13 +20,25 @@ Constants used by more than one module in the package.
|
|||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
MSG_BYTES = 1024000
|
|
||||||
CMD = 'command'
|
|
||||||
CMD_NAME = 'name'
|
|
||||||
CMD_POOL_SIZE = 'pool-size'
|
|
||||||
CMD_NUM_RUNNING = 'num-running'
|
|
||||||
CMD_START = 'start'
|
|
||||||
CMD_STOP = 'stop'
|
|
||||||
CMD_STOP_ALL = 'stop-all'
|
|
||||||
CMD_FUNC_NAME = 'func-name'
|
|
||||||
CLIENT_EXIT = 'exit'
|
CLIENT_EXIT = 'exit'
|
||||||
|
|
||||||
|
SESSION_MSG_BYTES = 1024 * 100
|
||||||
|
SESSION_PARSER_WRITER = 'session_writer'
|
||||||
|
|
||||||
|
|
||||||
|
class CLIENT_INFO:
|
||||||
|
__slots__ = ()
|
||||||
|
TERMINAL_WIDTH = 'terminal_width'
|
||||||
|
|
||||||
|
|
||||||
|
class CMD:
|
||||||
|
__slots__ = ()
|
||||||
|
CMD = 'command'
|
||||||
|
NAME = 'name'
|
||||||
|
POOL_SIZE = 'pool-size'
|
||||||
|
NUM_RUNNING = 'num-running'
|
||||||
|
START = 'start'
|
||||||
|
STOP = 'stop'
|
||||||
|
STOP_ALL = 'stop-all'
|
||||||
|
FUNC_NAME = 'func-name'
|
||||||
|
@ -31,74 +31,121 @@ from typing import Optional, Union
|
|||||||
from .client import ControlClient, UnixControlClient
|
from .client import ControlClient, UnixControlClient
|
||||||
from .pool import TaskPool, SimpleTaskPool
|
from .pool import TaskPool, SimpleTaskPool
|
||||||
from .session import ControlSession
|
from .session import ControlSession
|
||||||
|
from .types import ConnectedCallbackT
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
||||||
|
"""
|
||||||
|
Abstract base class for a task pool control server.
|
||||||
|
|
||||||
|
This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
|
||||||
|
connecting to it. The entire interface is defined within that session class.
|
||||||
|
"""
|
||||||
_client_class = ControlClient
|
_client_class = ControlClient
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@property
|
@property
|
||||||
def client_class_name(cls) -> str:
|
def client_class_name(cls) -> str:
|
||||||
|
"""Returns the name of the control client class matching the server class."""
|
||||||
return cls._client_class.__name__
|
return cls._client_class.__name__
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||||
|
"""
|
||||||
|
Initializes, starts, and returns an async server instance (Unix or TCP type).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_connected_cb:
|
||||||
|
The callback for when a client connects to the server (as per `asyncio.start_server` or
|
||||||
|
`asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
|
||||||
|
**kwargs (optional):
|
||||||
|
Keyword arguments to pass into the function that starts the server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The running server object (a type of `asyncio.Server`).
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def final_callback(self) -> None:
|
def _final_callback(self) -> None:
|
||||||
|
"""The method to run after the server's `serve_forever` methods ends for whatever reason."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Initializes by merely saving the internal attributes, but without starting the server yet.
|
||||||
|
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
|
||||||
|
tied to one specific task pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool:
|
||||||
|
An instance of a `BaseTaskPool` subclass to tie the server to.
|
||||||
|
**server_kwargs (optional):
|
||||||
|
Keyword arguments that will be passed into the function that starts the server.
|
||||||
|
"""
|
||||||
self._pool: Union[TaskPool, SimpleTaskPool] = pool
|
self._pool: Union[TaskPool, SimpleTaskPool] = pool
|
||||||
self._server_kwargs = server_kwargs
|
self._server_kwargs = server_kwargs
|
||||||
self._server: Optional[AbstractServer] = None
|
self._server: Optional[AbstractServer] = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"{self.__class__.__name__} for {self._pool}"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
|
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
|
||||||
|
"""Read-only property for accessing the task pool instance controlled by the server."""
|
||||||
return self._pool
|
return self._pool
|
||||||
|
|
||||||
def is_serving(self) -> bool:
|
def is_serving(self) -> bool:
|
||||||
|
"""Wrapper around the `asyncio.Server.is_serving` method."""
|
||||||
return self._server.is_serving()
|
return self._server.is_serving()
|
||||||
|
|
||||||
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
"""
|
||||||
|
The universal client callback that will be passed into the `_get_server_instance` method.
|
||||||
|
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
|
||||||
|
"""
|
||||||
session = ControlSession(self, reader, writer)
|
session = ControlSession(self, reader, writer)
|
||||||
await session.client_handshake()
|
await session.client_handshake()
|
||||||
await session.listen()
|
await session.listen()
|
||||||
|
|
||||||
async def _serve_forever(self) -> None:
|
async def _serve_forever(self) -> None:
|
||||||
|
"""
|
||||||
|
To be run as an `asyncio.Task` by the following method.
|
||||||
|
Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
|
||||||
|
method is called, when the former method ends for whatever reason.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
async with self._server:
|
async with self._server:
|
||||||
await self._server.serve_forever()
|
await self._server.serve_forever()
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
log.debug("%s stopped", self.__class__.__name__)
|
log.debug("%s stopped", self.__class__.__name__)
|
||||||
finally:
|
finally:
|
||||||
self.final_callback()
|
self._final_callback()
|
||||||
|
|
||||||
async def serve_forever(self) -> Task:
|
async def serve_forever(self) -> Task:
|
||||||
|
"""
|
||||||
|
This method actually starts the server and begins listening to client connections on the specified interface.
|
||||||
|
It should never block because the serving will be performed in a separate task.
|
||||||
|
"""
|
||||||
log.debug("Starting %s...", self.__class__.__name__)
|
log.debug("Starting %s...", self.__class__.__name__)
|
||||||
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
||||||
return create_task(self._serve_forever())
|
return create_task(self._serve_forever())
|
||||||
|
|
||||||
|
|
||||||
class UnixControlServer(ControlServer):
|
class UnixControlServer(ControlServer):
|
||||||
|
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
|
||||||
_client_class = UnixControlClient
|
_client_class = UnixControlClient
|
||||||
|
|
||||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||||
self._socket_path = Path(server_kwargs.pop('path'))
|
self._socket_path = Path(server_kwargs.pop('path'))
|
||||||
super().__init__(pool, **server_kwargs)
|
super().__init__(pool, **server_kwargs)
|
||||||
|
|
||||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||||
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||||
log.debug("Opened socket '%s'", str(self._socket_path))
|
log.debug("Opened socket '%s'", str(self._socket_path))
|
||||||
return srv
|
return server
|
||||||
|
|
||||||
def final_callback(self) -> None:
|
def _final_callback(self) -> None:
|
||||||
|
"""Removes the unix socket on which the server was listening."""
|
||||||
self._socket_path.unlink()
|
self._socket_path.unlink()
|
||||||
log.debug("Removed socket '%s'", str(self._socket_path))
|
log.debug("Removed socket '%s'", str(self._socket_path))
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace
|
from argparse import ArgumentError, HelpFormatter, Namespace
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
|
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from . import constants
|
from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
|
||||||
from .exceptions import HelpRequested
|
from .exceptions import HelpRequested
|
||||||
from .helpers import get_first_doc_line, return_or_exception, tasks_str
|
from .helpers import get_first_doc_line, return_or_exception, tasks_str
|
||||||
from .pool import TaskPool, SimpleTaskPool
|
from .pool import TaskPool, SimpleTaskPool
|
||||||
|
from .session_parser import CommandParser, NUM
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .server import ControlServer
|
from .server import ControlServer
|
||||||
@ -15,47 +16,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
NUM = 'num'
|
|
||||||
WIDTH = 'width'
|
|
||||||
|
|
||||||
|
|
||||||
class CommandParser(ArgumentParser):
|
|
||||||
@staticmethod
|
|
||||||
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
|
|
||||||
class ClientHelpFormatter(HelpFormatter):
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
kwargs[WIDTH] = terminal_width
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
return ClientHelpFormatter
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
parent: CommandParser = kwargs.pop('parent', None)
|
|
||||||
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
|
|
||||||
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
|
|
||||||
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
|
|
||||||
kwargs.setdefault('exit_on_error', False)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def stream_writer(self) -> StreamWriter:
|
|
||||||
return self._stream_writer
|
|
||||||
|
|
||||||
@property
|
|
||||||
def terminal_width(self) -> int:
|
|
||||||
return self._terminal_width
|
|
||||||
|
|
||||||
def _print_message(self, message: str, *args, **kwargs) -> None:
|
|
||||||
if message:
|
|
||||||
self.stream_writer.write(message.encode())
|
|
||||||
|
|
||||||
def exit(self, status: int = 0, message: str = None) -> None:
|
|
||||||
if message:
|
|
||||||
self._print_message(message)
|
|
||||||
|
|
||||||
def print_help(self, file=None) -> None:
|
|
||||||
super().print_help(file)
|
|
||||||
raise HelpRequested
|
|
||||||
|
|
||||||
|
|
||||||
class ControlSession:
|
class ControlSession:
|
||||||
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
|
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
@ -65,86 +25,65 @@ class ControlSession:
|
|||||||
self._reader: StreamReader = reader
|
self._reader: StreamReader = reader
|
||||||
self._writer: StreamWriter = writer
|
self._writer: StreamWriter = writer
|
||||||
self._parser: Optional[CommandParser] = None
|
self._parser: Optional[CommandParser] = None
|
||||||
|
self._subparsers = None
|
||||||
|
|
||||||
def _add_base_commands(self):
|
def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
|
||||||
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD)
|
**kwargs) -> CommandParser:
|
||||||
subparsers.add_parser(
|
if prog is None:
|
||||||
constants.CMD_NAME,
|
prog = name
|
||||||
prog=constants.CMD_NAME,
|
kwargs.setdefault('help', short_help or long_help)
|
||||||
help=get_first_doc_line(self._pool.__class__.__str__),
|
kwargs.setdefault('description', long_help or short_help)
|
||||||
parent=self._parser,
|
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
|
||||||
)
|
|
||||||
subparser_pool_size = subparsers.add_parser(
|
def _add_base_commands(self) -> None:
|
||||||
constants.CMD_POOL_SIZE,
|
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
|
||||||
prog=constants.CMD_POOL_SIZE,
|
self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
|
||||||
help="Get/set the maximum number of tasks in the pool",
|
self._add_parser_command(
|
||||||
parent=self._parser,
|
CMD.POOL_SIZE,
|
||||||
)
|
short_help="Get/set the maximum number of tasks in the pool.",
|
||||||
subparser_pool_size.add_argument(
|
formatter_class=HelpFormatter
|
||||||
NUM,
|
).add_optional_num_argument(
|
||||||
nargs='?',
|
default=None,
|
||||||
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
|
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
|
||||||
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
|
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
|
||||||
)
|
)
|
||||||
subparsers.add_parser(
|
self._add_parser_command(
|
||||||
constants.CMD_NUM_RUNNING,
|
CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)
|
||||||
help=get_first_doc_line(self._pool.__class__.num_running.fget),
|
|
||||||
parent=self._parser,
|
|
||||||
)
|
)
|
||||||
return subparsers
|
|
||||||
|
|
||||||
def _add_simple_commands(self):
|
def _add_simple_commands(self) -> None:
|
||||||
subparsers = self._add_base_commands()
|
self._add_parser_command(
|
||||||
subparser = subparsers.add_parser(
|
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
|
||||||
constants.CMD_START,
|
).add_optional_num_argument(
|
||||||
prog=constants.CMD_START,
|
help="Number of tasks to start."
|
||||||
help=get_first_doc_line(self._pool.__class__.start),
|
|
||||||
parent=self._parser,
|
|
||||||
)
|
)
|
||||||
subparser.add_argument(
|
self._add_parser_command(
|
||||||
NUM,
|
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
|
||||||
nargs='?',
|
).add_optional_num_argument(
|
||||||
type=int,
|
help="Number of tasks to stop."
|
||||||
default=1,
|
|
||||||
help="Number of tasks to start. Defaults to 1."
|
|
||||||
)
|
)
|
||||||
subparser = subparsers.add_parser(
|
self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
|
||||||
constants.CMD_STOP,
|
self._add_parser_command(
|
||||||
prog=constants.CMD_STOP,
|
CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)
|
||||||
help=get_first_doc_line(self._pool.__class__.stop),
|
|
||||||
parent=self._parser,
|
|
||||||
)
|
|
||||||
subparser.add_argument(
|
|
||||||
NUM,
|
|
||||||
nargs='?',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of tasks to stop. Defaults to 1."
|
|
||||||
)
|
|
||||||
subparsers.add_parser(
|
|
||||||
constants.CMD_STOP_ALL,
|
|
||||||
prog=constants.CMD_STOP_ALL,
|
|
||||||
help=get_first_doc_line(self._pool.__class__.stop_all),
|
|
||||||
parent=self._parser,
|
|
||||||
)
|
|
||||||
subparsers.add_parser(
|
|
||||||
constants.CMD_FUNC_NAME,
|
|
||||||
prog=constants.CMD_FUNC_NAME,
|
|
||||||
help=get_first_doc_line(self._pool.__class__.func_name.fget),
|
|
||||||
parent=self._parser,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_parser(self, client_terminal_width: int) -> None:
|
def _init_parser(self, client_terminal_width: int) -> None:
|
||||||
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width)
|
parser_kwargs = {
|
||||||
|
'prog': '',
|
||||||
|
SESSION_PARSER_WRITER: self._writer,
|
||||||
|
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
|
||||||
|
}
|
||||||
|
self._parser = CommandParser(**parser_kwargs)
|
||||||
|
self._add_base_commands()
|
||||||
if isinstance(self._pool, TaskPool):
|
if isinstance(self._pool, TaskPool):
|
||||||
pass # TODO
|
pass # TODO
|
||||||
elif isinstance(self._pool, SimpleTaskPool):
|
elif isinstance(self._pool, SimpleTaskPool):
|
||||||
self._add_simple_commands()
|
self._add_simple_commands()
|
||||||
|
|
||||||
async def client_handshake(self) -> None:
|
async def client_handshake(self) -> None:
|
||||||
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip())
|
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
|
||||||
log.debug("%s connected", self._client_class_name)
|
log.debug("%s connected", self._client_class_name)
|
||||||
self._init_parser(client_info[WIDTH])
|
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
|
||||||
self._writer.write(str(self._pool).encode())
|
self._writer.write(str(self._pool).encode())
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
|
||||||
@ -163,7 +102,7 @@ class ControlSession:
|
|||||||
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
|
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
|
||||||
else:
|
else:
|
||||||
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
|
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
|
||||||
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num))
|
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
|
||||||
|
|
||||||
async def _cmd_num_running(self, **_kwargs) -> None:
|
async def _cmd_num_running(self, **_kwargs) -> None:
|
||||||
log.debug("%s requests number of running tasks", self._client_class_name)
|
log.debug("%s requests number of running tasks", self._client_class_name)
|
||||||
@ -189,7 +128,7 @@ class ControlSession:
|
|||||||
|
|
||||||
async def _execute_command(self, args: Namespace) -> None:
|
async def _execute_command(self, args: Namespace) -> None:
|
||||||
args = vars(args)
|
args = vars(args)
|
||||||
cmd: str = args.pop(constants.CMD, None)
|
cmd: str = args.pop(CMD.CMD, None)
|
||||||
if cmd is not None:
|
if cmd is not None:
|
||||||
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
|
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
|
||||||
await method(**args)
|
await method(**args)
|
||||||
@ -210,7 +149,7 @@ class ControlSession:
|
|||||||
|
|
||||||
async def listen(self) -> None:
|
async def listen(self) -> None:
|
||||||
while self._control_server.is_serving():
|
while self._control_server.is_serving():
|
||||||
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip()
|
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
|
||||||
if not msg:
|
if not msg:
|
||||||
log.debug("%s disconnected", self._client_class_name)
|
log.debug("%s disconnected", self._client_class_name)
|
||||||
break
|
break
|
||||||
|
60
src/asyncio_taskpool/session_parser.py
Normal file
60
src/asyncio_taskpool/session_parser.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
|
||||||
|
from asyncio.streams import StreamWriter
|
||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO
|
||||||
|
from .exceptions import HelpRequested
|
||||||
|
|
||||||
|
|
||||||
|
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
|
||||||
|
FORMATTER_CLASS = 'formatter_class'
|
||||||
|
NUM = 'num'
|
||||||
|
|
||||||
|
|
||||||
|
class CommandParser(ArgumentParser):
|
||||||
|
@staticmethod
|
||||||
|
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
|
||||||
|
if base_cls is None:
|
||||||
|
base_cls = ArgumentDefaultsHelpFormatter
|
||||||
|
|
||||||
|
class ClientHelpFormatter(base_cls):
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
kwargs['width'] = terminal_width
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
return ClientHelpFormatter
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
parent: CommandParser = kwargs.pop('parent', None)
|
||||||
|
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER)
|
||||||
|
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
|
||||||
|
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
|
||||||
|
kwargs.setdefault('exit_on_error', False)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stream_writer(self) -> StreamWriter:
|
||||||
|
return self._stream_writer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def terminal_width(self) -> int:
|
||||||
|
return self._terminal_width
|
||||||
|
|
||||||
|
def _print_message(self, message: str, *args, **kwargs) -> None:
|
||||||
|
if message:
|
||||||
|
self.stream_writer.write(message.encode())
|
||||||
|
|
||||||
|
def exit(self, status: int = 0, message: str = None) -> None:
|
||||||
|
if message:
|
||||||
|
self._print_message(message)
|
||||||
|
|
||||||
|
def print_help(self, file=None) -> None:
|
||||||
|
super().print_help(file)
|
||||||
|
raise HelpRequested
|
||||||
|
|
||||||
|
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
|
||||||
|
if not name_or_flags:
|
||||||
|
name_or_flags = (NUM, )
|
||||||
|
kwargs.setdefault('nargs', '?')
|
||||||
|
kwargs.setdefault('default', 1)
|
||||||
|
kwargs.setdefault('type', int)
|
||||||
|
return self.add_argument(*name_or_flags, **kwargs)
|
@ -34,4 +34,5 @@ CoroutineFunc = Callable[[...], Awaitable[Any]]
|
|||||||
EndCallbackT = Callable
|
EndCallbackT = Callable
|
||||||
CancelCallbackT = Callable
|
CancelCallbackT = Callable
|
||||||
|
|
||||||
|
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||||
|
143
tests/test_server.py
Normal file
143
tests/test_server.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool import server
|
||||||
|
from asyncio_taskpool.client import ControlClient, UnixControlClient
|
||||||
|
|
||||||
|
|
||||||
|
FOO, BAR = 'foo', 'bar'
|
||||||
|
|
||||||
|
|
||||||
|
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||||
|
log_lvl: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.log_lvl = server.log.level
|
||||||
|
server.log.setLevel(999)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
server.log.setLevel(cls.log_lvl)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
|
||||||
|
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||||
|
self.mock_pool = MagicMock()
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.abstract_patcher.stop()
|
||||||
|
|
||||||
|
def test_client_class_name(self):
|
||||||
|
self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name)
|
||||||
|
|
||||||
|
async def test_abstract(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
args = [AsyncMock()]
|
||||||
|
await self.server._get_server_instance(*args)
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
self.server._final_callback()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(self.mock_pool, self.server._pool)
|
||||||
|
self.assertEqual(self.kwargs, self.server._server_kwargs)
|
||||||
|
self.assertIsNone(self.server._server)
|
||||||
|
|
||||||
|
def test_pool(self):
|
||||||
|
self.assertEqual(self.mock_pool, self.server.pool)
|
||||||
|
|
||||||
|
def test_is_serving(self):
|
||||||
|
self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR))
|
||||||
|
self.assertEqual(FOO + BAR, self.server.is_serving())
|
||||||
|
|
||||||
|
@patch.object(server, 'ControlSession')
|
||||||
|
async def test__client_connected_cb(self, mock_client_session_cls: MagicMock):
|
||||||
|
mock_client_handshake, mock_listen = AsyncMock(), AsyncMock()
|
||||||
|
mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen)
|
||||||
|
mock_reader, mock_writer = MagicMock(), MagicMock()
|
||||||
|
self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer))
|
||||||
|
mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer)
|
||||||
|
mock_client_handshake.assert_awaited_once_with()
|
||||||
|
mock_listen.assert_awaited_once_with()
|
||||||
|
|
||||||
|
@patch.object(server.ControlServer, '_final_callback')
|
||||||
|
async def test__serve_forever(self, mock__final_callback: MagicMock):
|
||||||
|
mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError)
|
||||||
|
self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever)
|
||||||
|
with self.assertLogs(server.log, logging.DEBUG):
|
||||||
|
self.assertIsNone(await self.server._serve_forever())
|
||||||
|
mock_aenter.assert_awaited_once_with()
|
||||||
|
mock_serve_forever.assert_awaited_once_with()
|
||||||
|
mock__final_callback.assert_called_once_with()
|
||||||
|
|
||||||
|
mock_aenter.reset_mock()
|
||||||
|
mock_serve_forever.reset_mock(side_effect=True)
|
||||||
|
mock__final_callback.reset_mock()
|
||||||
|
|
||||||
|
self.assertIsNone(await self.server._serve_forever())
|
||||||
|
mock_aenter.assert_awaited_once_with()
|
||||||
|
mock_serve_forever.assert_awaited_once_with()
|
||||||
|
mock__final_callback.assert_called_once_with()
|
||||||
|
|
||||||
|
@patch.object(server, 'create_task')
|
||||||
|
@patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock())
|
||||||
|
@patch.object(server.ControlServer, '_get_server_instance')
|
||||||
|
async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock,
|
||||||
|
mock_create_task: MagicMock):
|
||||||
|
mock__serve_forever.return_value = mock_awaitable = 'some_coroutine'
|
||||||
|
mock_create_task.return_value = expected_output = 12345
|
||||||
|
output = await self.server.serve_forever()
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs)
|
||||||
|
mock__serve_forever.assert_called_once_with()
|
||||||
|
mock_create_task.assert_called_once_with(mock_awaitable)
|
||||||
|
|
||||||
|
|
||||||
|
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
|
||||||
|
log_lvl: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.log_lvl = server.log.level
|
||||||
|
server.log.setLevel(999)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
server.log.setLevel(cls.log_lvl)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
|
||||||
|
self.mock_base_init = self.base_init_patcher.start()
|
||||||
|
self.mock_pool = MagicMock()
|
||||||
|
self.path = '/tmp/asyncio_taskpool'
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.base_init_patcher.stop()
|
||||||
|
|
||||||
|
def test__client_class(self):
|
||||||
|
self.assertEqual(UnixControlClient, self.server._client_class)
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(Path(self.path), self.server._socket_path)
|
||||||
|
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
|
||||||
|
|
||||||
|
@patch.object(server, 'start_unix_server')
|
||||||
|
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
|
||||||
|
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
|
||||||
|
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
|
||||||
|
args = [mock_callback]
|
||||||
|
output = await self.server._get_server_instance(*args, **mock_kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs)
|
||||||
|
|
||||||
|
def test__final_callback(self):
|
||||||
|
self.server._socket_path = MagicMock()
|
||||||
|
self.assertIsNone(self.server._final_callback())
|
||||||
|
self.server._socket_path.unlink.assert_called_once_with()
|
Reference in New Issue
Block a user