Compare commits

..

8 Commits

Author SHA1 Message Date
96d01e7259 full unit test coverage and docstrings for session module 2022-02-14 17:59:11 +01:00
3f3eb7ce38 minor rephrasing 2022-02-13 21:12:35 +01:00
05d51eface simplified method 2022-02-13 20:58:36 +01:00
b6aed727e9 additional unit tests 2022-02-13 19:55:27 +01:00
c9a8d9ecd1 better placeholders 2022-02-13 19:45:53 +01:00
538b9cc91c major refactoring of the control session/parser classes; restructured constants 2022-02-13 19:39:21 +01:00
3fb451a00e docstrings and full test coverage for server module; small adjustments 2022-02-13 16:17:50 +01:00
be03097bf4 massive overhaul of the control server interface;
making use of `ArgumentParser` for client commands;
new `ControlSession` object instantiated upon connection
2022-02-12 22:51:52 +01:00
15 changed files with 1053 additions and 95 deletions

View File

@@ -8,5 +8,8 @@ omit =
fail_under = 100
show_missing = True
skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
omit =
tests/*

View File

@@ -14,7 +14,7 @@ If you need control over a task pool at runtime, you can launch an asynchronous
## Usage
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like:
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
```python
from asyncio_taskpool import SimpleTaskPool
...

View File

@@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.2.1
version = 0.3.3
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@@ -19,13 +19,15 @@ Classes of control clients for a simply interface to a task pool control server.
"""
import json
import shutil
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from asyncio_taskpool import constants
from asyncio_taskpool.types import ClientConnT
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT
class ControlClient(ABC):
@@ -34,19 +36,29 @@ class ControlClient(ABC):
async def open_connection(self, **kwargs) -> ClientConnT:
raise NotImplementedError
@staticmethod
def client_info() -> dict:
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
def __init__(self, **conn_kwargs) -> None:
self._conn_kwargs = conn_kwargs
self._connected: bool = False
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
self._connected = True
writer.write(json.dumps(self.client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
try:
msg = input("> ").strip().lower()
except EOFError:
msg = constants.CLIENT_EXIT
msg = CLIENT_EXIT
except KeyboardInterrupt:
print()
return
if msg == constants.CLIENT_EXIT:
if msg == CLIENT_EXIT:
writer.close()
self._connected = False
return
@@ -57,15 +69,14 @@ class ControlClient(ABC):
self._connected = False
print(e, file=sys.stderr)
return
print((await reader.read(constants.MSG_BYTES)).decode())
print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self):
reader, writer = await self.open_connection(**self._conn_kwargs)
if reader is None:
print("Failed to connect.", file=sys.stderr)
return
self._connected = True
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
await self._server_handshake(reader, writer)
while self._connected:
await self._interact(reader, writer)
print("Disconnected from control server.")

View File

@@ -20,10 +20,25 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all'
CMD_NUM_RUNNING = 'num_running'
CMD_FUNC = 'func'
CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
SESSION_PARSER_WRITER = 'session_writer'
class CLIENT_INFO:
__slots__ = ()
TERMINAL_WIDTH = 'terminal_width'
class CMD:
__slots__ = ()
CMD = 'command'
NAME = 'name'
POOL_SIZE = 'pool-size'
NUM_RUNNING = 'num-running'
START = 'start'
STOP = 'stop'
STOP_ALL = 'stop-all'
FUNC_NAME = 'func-name'

View File

@@ -49,3 +49,19 @@ class PoolStillUnlocked(PoolException):
class NotCoroutine(PoolException):
pass
class ServerException(Exception):
pass
class UnknownTaskPoolClass(ServerException):
pass
class NotATaskPool(ServerException):
pass
class HelpRequested(ServerException):
pass

View File

@@ -21,7 +21,8 @@ Miscellaneous helper functions.
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from typing import Any, Optional
from inspect import getdoc
from typing import Any, Optional, Union
from .types import T, AnyCallableT, ArgsT, KwArgsT
@@ -48,3 +49,21 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
async def join_queue(q: Queue) -> None:
await q.join()
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0].strip()
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
try:
if iscoroutinefunction(_function_to_execute):
return await _function_to_execute(*args, **kwargs)
else:
return _function_to_execute(*args, **kwargs)
except Exception as e:
return e

View File

@@ -78,6 +78,7 @@ class BaseTaskPool:
log.debug("%s initialized", str(self))
def __str__(self) -> str:
"""Returns the name of the task pool."""
return f'{self.__class__.__name__}-{self._name or self._idx}'
@property

View File

@@ -26,126 +26,126 @@ from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Tuple, Union, Optional
from typing import Optional, Union
from . import constants
from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession
from .types import ConnectedCallbackT
log = logging.getLogger(__name__)
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
cmd = msg.strip().split(' ', 1)
if len(cmd) > 1:
try:
return cmd[0], int(cmd[1])
except ValueError:
return None, None
return cmd[0], None
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
client_class = ControlClient
"""
Abstract base class for a task pool control server.
This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
connecting to it. The entire interface is defined within that session class.
"""
_client_class = ControlClient
@classmethod
@property
def client_class_name(cls) -> str:
"""Returns the name of the control client class matching the server class."""
return cls._client_class.__name__
@abstractmethod
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
"""
Initializes, starts, and returns an async server instance (Unix or TCP type).
Args:
client_connected_cb:
The callback for when a client connects to the server (as per `asyncio.start_server` or
`asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
**kwargs (optional):
Keyword arguments to pass into the function that starts the server.
Returns:
The running server object (a type of `asyncio.Server`).
"""
raise NotImplementedError
@abstractmethod
def final_callback(self) -> None:
def _final_callback(self) -> None:
"""The method to run after the server's `serve_forever` methods ends for whatever reason."""
raise NotImplementedError
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._pool: SimpleTaskPool = pool
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
"""
Initializes by merely saving the internal attributes, but without starting the server yet.
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
tied to one specific task pool.
Args:
pool:
An instance of a `BaseTaskPool` subclass to tie the server to.
**server_kwargs (optional):
Keyword arguments that will be passed into the function that starts the server.
"""
self._pool: Union[TaskPool, SimpleTaskPool] = pool
self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None:
num = 1
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
writer.write(str(await self._pool.start(num)).encode())
@property
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
"""Read-only property for accessing the task pool instance controlled by the server."""
return self._pool
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None:
num = 1
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
# the requested number may be greater than the total number of running tasks
writer.write(str(self._pool.stop(num)).encode())
def _stop_all_tasks(self, writer: StreamWriter) -> None:
log.debug("%s requests stopping all tasks", self.client_class.__name__)
writer.write(str(self._pool.stop_all()).encode())
def _pool_size(self, writer: StreamWriter) -> None:
log.debug("%s requests number of running tasks", self.client_class.__name__)
writer.write(str(self._pool.num_running).encode())
def _pool_func(self, writer: StreamWriter) -> None:
log.debug("%s requests pool function", self.client_class.__name__)
writer.write(self._pool.func_name.encode())
async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
while self._server.is_serving():
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self.client_class.__name__)
break
cmd, arg = get_cmd_arg(msg)
if cmd == constants.CMD_START:
await self._start_tasks(writer, arg)
elif cmd == constants.CMD_STOP:
self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL:
self._stop_all_tasks(writer)
elif cmd == constants.CMD_NUM_RUNNING:
self._pool_size(writer)
elif cmd == constants.CMD_FUNC:
self._pool_func(writer)
else:
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
writer.write(b"Invalid command!")
await writer.drain()
def is_serving(self) -> bool:
"""Wrapper around the `asyncio.Server.is_serving` method."""
return self._server.is_serving()
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
log.debug("%s connected", self.client_class.__name__)
writer.write(str(self._pool).encode())
await writer.drain()
await self._listen(reader, writer)
"""
The universal client callback that will be passed into the `_get_server_instance` method.
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
"""
session = ControlSession(self, reader, writer)
await session.client_handshake()
await session.listen()
async def _serve_forever(self) -> None:
"""
To be run as an `asyncio.Task` by the following method.
Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
method is called, when the former method ends for whatever reason.
"""
try:
async with self._server:
await self._server.serve_forever()
except CancelledError:
log.debug("%s stopped", self.__class__.__name__)
finally:
self.final_callback()
self._final_callback()
async def serve_forever(self) -> Task:
"""
This method actually starts the server and begins listening to client connections on the specified interface.
It should never block because the serving will be performed in a separate task.
"""
log.debug("Starting %s...", self.__class__.__name__)
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
return create_task(self._serve_forever())
class UnixControlServer(ControlServer):
client_class = UnixControlClient
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._socket_path = Path(server_kwargs.pop('path'))
super().__init__(pool, **server_kwargs)
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
log.debug("Opened socket '%s'", str(self._socket_path))
return srv
return server
def final_callback(self) -> None:
def _final_callback(self) -> None:
"""Removes the unix socket on which the server was listening."""
self._socket_path.unlink()
log.debug("Removed socket '%s'", str(self._socket_path))

View File

@@ -0,0 +1,304 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the control session class used by the control server.
"""
import logging
import json
from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING
from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM
if TYPE_CHECKING:
from .server import ControlServer
log = logging.getLogger(__name__)
class ControlSession:
"""
This class defines the API for controlling a task pool instance from the outside.
The commands received from a connected client are translated into method calls on the task pool instance.
A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
"""
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
"""
Instantiation should happen once a client connection to the control server has already been established.
For more convenient/efficient access, some of the server's properties are saved in separate attributes.
The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.
Args:
server:
The instance of a `ControlServer` subclass starting the session.
reader:
The `asyncio.StreamReader` created when a client connected to the server.
writer:
The `asyncio.StreamWriter` created when a client connected to the server.
"""
self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None
self._subparsers = None
def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
**kwargs) -> CommandParser:
"""
Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.
Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.
Args:
name:
The command name; passed directly into the `add_parser` method.
prog (optional):
Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
equal to the `name` argument.
short_help (optional):
Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
`long_help` argument is present; in that case the `long_help` argument is passed as `help`.
long_help (optional):
Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
**kwargs (optional):
Any keyword-arguments to directly pass into the `add_parser` method.
Returns:
An instance of the `CommandParser` class representing the newly added control command.
"""
if prog is None:
prog = name
kwargs.setdefault('help', short_help or long_help)
kwargs.setdefault('description', long_help or short_help)
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
def _add_base_commands(self) -> None:
"""
Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.
These include commands mapping to the following pool methods:
- __str__
- pool_size (get/set property)
- num_running
"""
self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
self._add_command(
CMD.POOL_SIZE,
short_help="Get/set the maximum number of tasks in the pool.",
formatter_class=HelpFormatter
).add_optional_num_argument(
default=None,
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
)
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
def _add_simple_commands(self) -> None:
"""
Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.
These include commands mapping to the following pool methods:
- start
- stop
- stop_all
- func_name
"""
self._add_command(
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
).add_optional_num_argument(
help="Number of tasks to start."
)
self._add_command(
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
).add_optional_num_argument(
help="Number of tasks to stop."
)
self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))
def _add_advanced_commands(self) -> None:
"""
Adds the commands that are only supported, if a `TaskPool` object is controlled.
These include commands mapping to the following pool methods:
- ...
"""
raise NotImplementedError
def _init_parser(self, client_terminal_width: int) -> None:
"""
Initializes and fully configures the `CommandParser` responsible for handling the input.
Depending on what specific task pool class is controlled by the server, different commands are added.
Args:
client_terminal_width:
The number of columns of the client's terminal to be able to nicely format messages from the parser.
"""
parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
}
self._parser = CommandParser(**parser_kwargs)
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
self._add_base_commands()
if isinstance(self._pool, TaskPool):
self._add_advanced_commands()
elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands()
elif isinstance(self._pool, BaseTaskPool):
raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}")
else:
raise NotATaskPool(f"Not a task pool instance: {self._pool}")
async def client_handshake(self) -> None:
"""
This method must be invoked before starting any other client interaction.
Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
"""
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
self._writer.write(str(self._pool).encode())
await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
"""
Acts as a wrapper around a call to a specific task pool method.
The method is called and any exception is caught and saved. If there is no output and no exception caught, a
generic confirmation message is sent back to the client. Otherwise the output or a string representation of
the exception caught is sent back.
Args:
func:
Reference to the task pool method.
*args (optional):
Any positional arguments to call the method with.
*+kwargs (optional):
Any keyword-arguments to call the method with.
"""
output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None:
"""Maps to the `__str__` method of any task pool class."""
log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None:
"""Maps to the `pool_size` property of any task pool class."""
num = kwargs.get(NUM)
if num is None:
log.debug("%s requests pool size", self._client_class_name)
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_num_running(self, **_kwargs) -> None:
"""Maps to the `num_running` property of any task pool class."""
log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_start(self, **kwargs) -> None:
"""Maps to the `start` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None:
"""Maps to the `stop` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None:
"""Maps to the `stop_all` method of the `SimpleTaskPool` class."""
log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None:
"""Maps to the `func_name` method of the `SimpleTaskPool` class."""
log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
async def _execute_command(self, **kwargs) -> None:
"""
Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.
Args:
**kwargs:
Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
simply passed into the method determined from the command name.
"""
method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
await method(**kwargs)
async def _parse_command(self, msg: str) -> None:
"""
Takes a message from the client and attempts to parse it.
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
`CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
dictionary of keyword-arguments returned by the `CommandParser` passed into it.
Args:
msg:
The non-empty string read from the client stream.
"""
try:
kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e:
self._writer.write(str(e).encode())
return
except HelpRequested:
return
await self._execute_command(**kwargs)
async def listen(self) -> None:
"""
Enters the main control loop that only ends if either the server or the client disconnect.
Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
This method should be called, when the client connection was established and the handshake was successful.
It will obviously block indefinitely.
"""
while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self._client_class_name)
break
await self._parse_command(msg)
await self._writer.drain()

View File

@@ -0,0 +1,60 @@
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter
from typing import Type, TypeVar
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO
from .exceptions import HelpRequested
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
FORMATTER_CLASS = 'formatter_class'
NUM = 'num'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER)
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
if not name_or_flags:
name_or_flags = (NUM, )
kwargs.setdefault('nargs', '?')
kwargs.setdefault('default', 1)
kwargs.setdefault('type', int)
return self.add_argument(*name_or_flags, **kwargs)

View File

@@ -28,10 +28,11 @@ T = TypeVar('T')
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable
CancelCallbackT = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

View File

@@ -86,3 +86,43 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
mock_queue = MagicMock(join=mock_join)
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()
def test_task_str(self):
self.assertEqual("task", helpers.tasks_str(1))
self.assertEqual("tasks", helpers.tasks_str(0))
self.assertEqual("tasks", helpers.tasks_str(-1))
self.assertEqual("tasks", helpers.tasks_str(2))
self.assertEqual("tasks", helpers.tasks_str(-10))
self.assertEqual("tasks", helpers.tasks_str(42))
def test_get_first_doc_line(self):
expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output}
something else
even more
""")
output = helpers.get_first_doc_line(mock_obj)
self.assertEqual(expected_output, output)
async def test_return_or_exception(self):
expected_output = '420'
mock_func = AsyncMock(return_value=expected_output)
args = (1, 3, 5)
kwargs = {'a': 1, 'b': 2, 'c': 'foo'}
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_awaited_once_with(*args, **kwargs)
mock_func = MagicMock(return_value=expected_output)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
class TestException(Exception):
pass
test_exception = TestException()
mock_func = MagicMock(side_effect=test_exception)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(test_exception, output)
mock_func.assert_called_once_with(*args, **kwargs)

164
tests/test_server.py Normal file
View File

@@ -0,0 +1,164 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.server` module.
"""
import asyncio
import logging
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import server
from asyncio_taskpool.client import ControlClient, UnixControlClient
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = server.log.level
server.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_pool = MagicMock()
self.kwargs = {FOO: 123, BAR: 456}
self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs)
def tearDown(self) -> None:
self.abstract_patcher.stop()
def test_client_class_name(self):
self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name)
async def test_abstract(self):
with self.assertRaises(NotImplementedError):
args = [AsyncMock()]
await self.server._get_server_instance(*args)
with self.assertRaises(NotImplementedError):
self.server._final_callback()
def test_init(self):
self.assertEqual(self.mock_pool, self.server._pool)
self.assertEqual(self.kwargs, self.server._server_kwargs)
self.assertIsNone(self.server._server)
def test_pool(self):
self.assertEqual(self.mock_pool, self.server.pool)
def test_is_serving(self):
self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR))
self.assertEqual(FOO + BAR, self.server.is_serving())
@patch.object(server, 'ControlSession')
async def test__client_connected_cb(self, mock_client_session_cls: MagicMock):
mock_client_handshake, mock_listen = AsyncMock(), AsyncMock()
mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen)
mock_reader, mock_writer = MagicMock(), MagicMock()
self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer))
mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer)
mock_client_handshake.assert_awaited_once_with()
mock_listen.assert_awaited_once_with()
@patch.object(server.ControlServer, '_final_callback')
async def test__serve_forever(self, mock__final_callback: MagicMock):
mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError)
self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever)
with self.assertLogs(server.log, logging.DEBUG):
self.assertIsNone(await self.server._serve_forever())
mock_aenter.assert_awaited_once_with()
mock_serve_forever.assert_awaited_once_with()
mock__final_callback.assert_called_once_with()
mock_aenter.reset_mock()
mock_serve_forever.reset_mock(side_effect=True)
mock__final_callback.reset_mock()
self.assertIsNone(await self.server._serve_forever())
mock_aenter.assert_awaited_once_with()
mock_serve_forever.assert_awaited_once_with()
mock__final_callback.assert_called_once_with()
@patch.object(server, 'create_task')
@patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock())
@patch.object(server.ControlServer, '_get_server_instance')
async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock,
mock_create_task: MagicMock):
mock__serve_forever.return_value = mock_awaitable = 'some_coroutine'
mock_create_task.return_value = expected_output = 12345
output = await self.server.serve_forever()
self.assertEqual(expected_output, output)
mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs)
mock__serve_forever.assert_called_once_with()
mock_create_task.assert_called_once_with(mock_awaitable)
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = server.log.level
server.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.mock_pool = MagicMock()
self.path = '/tmp/asyncio_taskpool'
self.kwargs = {FOO: 123, BAR: 456}
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test__client_class(self):
self.assertEqual(UnixControlClient, self.server._client_class)
def test_init(self):
self.assertEqual(Path(self.path), self.server._socket_path)
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
@patch.object(server, 'start_unix_server')
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
args = [mock_callback]
output = await self.server._get_server_instance(*args, **mock_kwargs)
self.assertEqual(expected_output, output)
mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs)
def test__final_callback(self):
self.server._socket_path = MagicMock()
self.assertIsNone(self.server._final_callback())
self.server._socket_path.unlink.assert_called_once_with()

324
tests/test_session.py Normal file
View File

@@ -0,0 +1,324 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.session` module.
"""
import json
from argparse import ArgumentError, Namespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = session.log.level
session.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
session.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
self.mock_client_class_name = FOO + BAR
self.mock_server = MagicMock(pool=self.mock_pool,
client_class_name=self.mock_client_class_name)
self.mock_reader = MagicMock()
self.mock_writer = MagicMock()
self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
def test_init(self):
self.assertEqual(self.mock_server, self.session._control_server)
self.assertEqual(self.mock_pool, self.session._pool)
self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser)
self.assertIsNone(self.session._subparsers)
def test__add_command(self):
expected_output = 123456
mock_add_parser = MagicMock(return_value=expected_output)
self.session._subparsers = MagicMock(add_parser=mock_add_parser)
self.session._parser = MagicMock()
name, prog, short_help, long_help = 'abc', None, 'short123', None
kwargs = {'x': 1, 'y': 2}
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
parent=self.session._parser, **kwargs)
mock_add_parser.reset_mock()
prog, long_help = 'ffffff', 'so long, wow'
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
parent=self.session._parser, **kwargs)
mock_add_parser.reset_mock()
short_help = None
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
parent=self.session._parser, **kwargs)
@patch.object(session, 'get_first_doc_line')
@patch.object(session.ControlSession, '_add_command')
def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
self.assertIsNone(self.session._add_base_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
mock__add_command.reset_mock()
mock_get_first_doc_line.reset_mock()
self.assertIsNone(self.session._add_simple_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
with self.assertRaises(NotImplementedError):
self.session._add_advanced_commands()
@patch.object(session.ControlSession, '_add_simple_commands')
@patch.object(session.ControlSession, '_add_advanced_commands')
@patch.object(session.ControlSession, '_add_base_commands')
@patch.object(session, 'CommandParser')
def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
mock_command_parser_cls.return_value = mock_parser = MagicMock()
self.session._pool = TaskPool()
width = 1234
expected_parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
}
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_called_once_with()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
async def fake_coroutine(): pass
self.session._pool = SimpleTaskPool(fake_coroutine)
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_called_once_with()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
class FakeTaskPool(BaseTaskPool):
pass
self.session._pool = FakeTaskPool()
with self.assertRaises(UnknownTaskPoolClass):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
self.session._pool = MagicMock()
with self.assertRaises(NotATaskPool):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
@patch.object(session.ControlSession, '_init_parser')
async def test_client_handshake(self, mock__init_parser: MagicMock):
width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock()
self.assertIsNone(await self.session.client_handshake())
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
mock__init_parser.assert_called_once_with(width)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session, 'return_or_exception')
async def test__write_function_output(self, mock_return_or_exception: MagicMock):
self.mock_writer.write = MagicMock()
mock_return_or_exception.return_value = None
func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(b"ok")
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
mock_return_or_exception.return_value = output = MagicMock()
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(str(output).encode())
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_name())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fset, self.session._pool, num
)
mock__write_function_output.reset_mock()
kwargs.pop(session.NUM)
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_num_running())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.num_running.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_start(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_start(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_stop(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_stop_all())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_func_name())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.func_name.fget, self.session._pool
)
async def test__execute_command(self):
mock_method = AsyncMock()
cmd = 'this-is-a-test'
setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
kwargs = {FOO: BAR, 'hello': 'python'}
self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
mock_method.assert_awaited_once_with(**kwargs)
@patch.object(session.ControlSession, '_execute_command')
async def test__parse_command(self, mock__execute_command: AsyncMock):
msg = 'asdf asd as a'
kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_awaited_once_with(**kwargs)
mock__execute_command.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__execute_command.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_not_awaited()
@patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty():
self.mock_reader.read.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg)
self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock()
self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited()
self.mock_writer.drain.assert_not_awaited()