generated from daniil-berg/boilerplate-py
Compare commits
11 Commits
4994135062
d0c0177681
bac7b32342
96d01e7259
3f3eb7ce38
05d51eface
b6aed727e9
c9a8d9ecd1
538b9cc91c
3fb451a00e
be03097bf4
@@ -8,5 +8,8 @@ omit =
fail_under = 100
show_missing = True
skip_covered = False
exclude_lines =
    if TYPE_CHECKING:
    if __name__ == ['"]__main__['"]:
omit =
    tests/*

@@ -14,7 +14,7 @@ If you need control over a task pool at runtime, you can launch an asynchronous

## Usage

Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like:
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
```python
from asyncio_taskpool import SimpleTaskPool
...
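# --- illustrative sketch, not part of the diff ---
# The constructor signature below is an assumption; only start(), stop() and stop_all()
# are confirmed by the server and session hunks further down in this comparison.

async def work(item):
    ...  # the coroutine function each task runs

async def main():
    pool = SimpleTaskPool(work, ('some_arg',))  # hypothetical constructor call
    await pool.start(5)   # spawn five tasks running work('some_arg')
    ...
    pool.stop(2)          # cancel two of them again
```
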
@@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.2.1
version = 0.3.5
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

@@ -19,64 +19,150 @@ Classes of control clients for a simply interface to a task pool control server.
"""


import json
import shutil
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from typing import Optional

from asyncio_taskpool import constants
from asyncio_taskpool.types import ClientConnT
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT, PathT


class ControlClient(ABC):
    """
    Abstract base class for a simple implementation of a task pool control client.

    Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
    TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
    This is a minimal working implementation.
    """

    @staticmethod
    def client_info() -> dict:
        """Returns a dictionary of client information relevant for the handshake with the server."""
        return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}

    @abstractmethod
    async def open_connection(self, **kwargs) -> ClientConnT:
    async def _open_connection(self, **kwargs) -> ClientConnT:
        """
        Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.

        This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
        as keyword-arguments.
        This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
        `None` and `None`, if it failed to establish the defined connection.
        """
        raise NotImplementedError

    def __init__(self, **conn_kwargs) -> None:
        """Simply stores the connection keyword-arguments necessary for opening the connection."""
        self._conn_kwargs = conn_kwargs
        self._connected: bool = False

    async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
    async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
        """
        Performs the first interaction with the server providing it with the necessary client information.

        Upon completion, the server's info is printed.

        Args:
            reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
            writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
        """
        self._connected = True
        writer.write(json.dumps(self.client_info()).encode())
        await writer.drain()
        print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())

    def _get_command(self, writer: StreamWriter) -> Optional[str]:
        """
        Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.

        Args:
            writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method

        Returns:
            `None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
            otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
        """
        try:
            msg = input("> ").strip().lower()
        except EOFError:
            msg = constants.CLIENT_EXIT
        except KeyboardInterrupt:
        except EOFError:  # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
            msg = CLIENT_EXIT
        except KeyboardInterrupt:  # Ctrl+C shall simply reset to the input prompt.
            print()
            return
        if msg == constants.CLIENT_EXIT:
        if msg == CLIENT_EXIT:
            writer.close()
            self._connected = False
            return
        return msg

    async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
        """
        Reacts to the user's command, potentially performing a back-and-forth interaction with the server.

        If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
        If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.

        Args:
            reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
            writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
        """
        cmd = self._get_command(writer)
        if cmd is None:
            return
        try:
            writer.write(msg.encode())
            # Send the command to the server.
            writer.write(cmd.encode())
            await writer.drain()
        except ConnectionError as e:
            self._connected = False
            print(e, file=sys.stderr)
            return
        print((await reader.read(constants.MSG_BYTES)).decode())
        # Await the server's response, then print it.
        print((await reader.read(SESSION_MSG_BYTES)).decode())

    async def start(self):
        reader, writer = await self.open_connection(**self._conn_kwargs)
    async def start(self) -> None:
        """
        This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.

        If the connection can not be established, an error message is printed to `stderr` and the method returns.
        If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
        disconnected-message.
        """
        reader, writer = await self._open_connection(**self._conn_kwargs)
        if reader is None:
            print("Failed to connect.", file=sys.stderr)
            return
        self._connected = True
        print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
        await self._server_handshake(reader, writer)
        while self._connected:
            await self._interact(reader, writer)
        print("Disconnected from control server.")


class UnixControlClient(ControlClient):
    def __init__(self, **conn_kwargs) -> None:
        self._socket_path = Path(conn_kwargs.pop('path'))
    """Task pool control client that expects a unix socket to be exposed by the control server."""

    def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
        """
        In addition to what the base class does, the `socket_path` is expected as a non-optional argument.

        The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
        """
        self._socket_path = Path(socket_path)
        super().__init__(**conn_kwargs)

    async def open_connection(self, **kwargs) -> ClientConnT:
    async def _open_connection(self, **kwargs) -> ClientConnT:
        """
        Wrapper around the `asyncio.open_unix_connection` function.

        Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
        otherwise, the stream-reader and -writer tuple is returned.
        """
        try:
            return await open_unix_connection(self._socket_path, **kwargs)
        except FileNotFoundError:

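For orientation, a hedged sketch of how this client could be launched; the constructor and `start()` come from the hunk above, while the socket path is purely hypothetical:

```python
import asyncio
from asyncio_taskpool.client import UnixControlClient

async def main():
    # Connects to the control server's unix socket, performs the handshake,
    # then loops on the "> " input prompt until 'exit' (or Ctrl+D) is entered.
    await UnixControlClient(socket_path='/tmp/taskpool.sock').start()

if __name__ == '__main__':
    asyncio.run(main())
```
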
@@ -20,10 +20,25 @@ Constants used by more than one module in the package.


PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all'
CMD_NUM_RUNNING = 'num_running'
CMD_FUNC = 'func'

CLIENT_EXIT = 'exit'

SESSION_MSG_BYTES = 1024 * 100
SESSION_WRITER = 'session_writer'


class CLIENT_INFO:
    __slots__ = ()
    TERMINAL_WIDTH = 'terminal_width'


class CMD:
    __slots__ = ()
    CMD = 'command'
    NAME = 'name'
    POOL_SIZE = 'pool-size'
    NUM_RUNNING = 'num-running'
    START = 'start'
    STOP = 'stop'
    STOP_ALL = 'stop-all'
    FUNC_NAME = 'func-name'

@@ -49,3 +49,19 @@ class PoolStillUnlocked(PoolException):

class NotCoroutine(PoolException):
    pass


class ServerException(Exception):
    pass


class UnknownTaskPoolClass(ServerException):
    pass


class NotATaskPool(ServerException):
    pass


class HelpRequested(ServerException):
    pass

@@ -21,7 +21,8 @@ Miscellaneous helper functions.

from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from typing import Any, Optional
from inspect import getdoc
from typing import Any, Optional, Union

from .types import T, AnyCallableT, ArgsT, KwArgsT

@@ -48,3 +49,21 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:

async def join_queue(q: Queue) -> None:
    await q.join()


def tasks_str(num: int) -> str:
    return "tasks" if num != 1 else "task"


def get_first_doc_line(obj: object) -> str:
    return getdoc(obj).strip().split("\n", 1)[0].strip()


async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
    try:
        if iscoroutinefunction(_function_to_execute):
            return await _function_to_execute(*args, **kwargs)
        else:
            return _function_to_execute(*args, **kwargs)
    except Exception as e:
        return e

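A quick illustration of how `return_or_exception` behaves (illustrative only; this is the helper that the new session module's `_write_function_output` relies on to send either a result or the exception text back to the client):

```python
import asyncio
from asyncio_taskpool.helpers import return_or_exception

async def divide(a, b):
    return a / b

async def demo():
    print(await return_or_exception(divide, 1, 2))  # 0.5
    print(await return_or_exception(divide, 1, 0))  # the ZeroDivisionError instance is returned, not raised

asyncio.run(demo())
```
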
@@ -78,6 +78,7 @@ class BaseTaskPool:
        log.debug("%s initialized", str(self))

    def __str__(self) -> str:
        """Returns the name of the task pool."""
        return f'{self.__class__.__name__}-{self._name or self._idx}'

    @property

@@ -318,8 +319,8 @@ class BaseTaskPool:
        """
        Cancels all tasks still running within the pool.

        Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
        This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
        Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
        This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
        than the number of arguments from `args_iter`.
        In this case, those already running will be cancelled, while the following will **never even start**.

@@ -354,8 +355,8 @@ class BaseTaskPool:

        The `lock()` method must have been called prior to this.

        Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
        This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
        Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
        This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
        than the number of arguments from `args_iter`.
        In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
        blocking this method. Otherwise it will wait until they all have started.

@@ -551,10 +552,10 @@ class TaskPool(BaseTaskPool):
    def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
        """
        Helper function for `_map()`.
        Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to a new `asyncio.Queue`.
        Takes the iterable of function arguments `args_iter` and adds up to `group_size` to a new `asyncio.Queue`.
        The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.

        If the iterable contains less than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer`
        If the iterable contains less than `group_size` elements, nothing else happens; otherwise the `_queue_producer`
        is started as a separate task with the arguments queue and and iterator of the remaining arguments.

        Args:

@@ -566,7 +567,7 @@ class TaskPool(BaseTaskPool):
        Returns:
            The newly created and filled arguments queue for spawning new tasks.
        """
        # Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run
        # Setting the `maxsize` of the queue to `group_size` will ensure that no more than `group_size` tasks will run
        # concurrently because the size of the queue is what will determine the number of immediately started tasks in
        # the `_map()` method and each of those will only ever start (at most) one other task upon ending.
        args_queue = Queue(maxsize=num_tasks)

@@ -574,12 +575,12 @@ class TaskPool(BaseTaskPool):
        args_iter = iter(args_iter)
        try:
            # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
            # tasks, which will be at most `num_tasks` (meaning the queue will be full).
            # tasks, which will be at most `group_size` (meaning the queue will be full).
            for i in range(num_tasks):
                args_queue.put_nowait(next(args_iter))
        except StopIteration:
            # If we get here, this means that the number of elements in the arguments iterator was less than the
            # specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue.
            # specified `group_size`. Still, the number of tasks to start immediately will be the size of the queue.
            # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
            pass
        else:

@@ -590,17 +591,27 @@ class TaskPool(BaseTaskPool):
            create_task(self._queue_producer(args_queue, args_iter))
        return args_queue

    async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
    async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, group_size: int = 1,
                   end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
        TODO: If task groups are implemented, consider adding all tasks from one call of this method to the same group
              and referring to "group size" rather than chunk/batch size.
        Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable.
        Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.

        This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
        Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `args_iter`.

        It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
        The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
        any given moment in time. Assuming the number of elements produced by `args_iter` is greater than `group_size`,
        this method will block **only** until the first `group_size` tasks have been **started**, before returning.
        (If the number of elements from `args_iter` is smaller than `group_size`, this method will return as soon as
        all of them have been started.)

        As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
        the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
        limit, this ensures that the number of tasks running concurrently as a result of this method call is always
        equal to `group_size` (except for when `args_iter` is exhausted of course).

        Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.

        This method sets up an internal arguments queue which is continuously filled while consuming the `args_iter`.

        Args:
            func:

@@ -609,8 +620,8 @@ class TaskPool(BaseTaskPool):
                The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
            arg_stars (optional):
                Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
            num_tasks (optional):
                The maximum number of the new tasks to run concurrently.
            group_size (optional):
                The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
            end_callback (optional):
                A callback to execute after a task has ended.
                It is run with the task's ID as its only positional argument.

@@ -623,7 +634,7 @@ class TaskPool(BaseTaskPool):
        """
        if not self._locked:
            raise exceptions.PoolIsLocked("Cannot start new tasks")
        args_queue = self._set_up_args_queue(args_iter, num_tasks)
        args_queue = self._set_up_args_queue(args_iter, group_size)
        # We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
        # `_queue_callback` triggered by one or more of them.
        # This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more

@@ -637,29 +648,34 @@ class TaskPool(BaseTaskPool):
        # Now the callbacks can immediately trigger more tasks.
        first_batch_started.set()

    async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1,
    async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1,
                  end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.

        Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
        Each coroutine looks like `func(arg)`, `arg` being an element from the iterable.
        Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
        Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.

        Once the first batch of tasks has started to run, this method returns.
        As soon as on of them finishes, it triggers the start of a new task (assuming there is room in the pool)
        consuming the next element from the arguments iterable.
        If the size of the pool never imposes a limit, this ensures that there is almost continuously the desired number
        of tasks from this call concurrently running within the pool.
        The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
        any given moment in time. Assuming the number of elements produced by `arg_iter` is greater than `group_size`,
        this method will block **only** until the first `group_size` tasks have been **started**, before returning.
        (If the number of elements from `arg_iter` is smaller than `group_size`, this method will return as soon as
        all of them have been started.)

        This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
        As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
        the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
        limit, this ensures that the number of tasks running concurrently as a result of this method call is always
        equal to `group_size` (except for when `arg_iter` is exhausted of course).

        Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.

        Args:
            func:
                The coroutine function to use for spawning the new tasks within the task pool.
            arg_iter:
                The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
            num_tasks (optional):
                The maximum number of the new tasks to run concurrently.
            group_size (optional):
                The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
            end_callback (optional):
                A callback to execute after a task has ended.
                It is run with the task's ID as its only positional argument.

@@ -671,27 +687,27 @@ class TaskPool(BaseTaskPool):
            `PoolIsLocked` if the pool has been locked.
            `NotCoroutine` if `func` is not a coroutine function.
        """
        await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,
        await self._map(func, arg_iter, arg_stars=0, group_size=group_size,
                        end_callback=end_callback, cancel_callback=cancel_callback)

    async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
    async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
                      end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as
        positional arguments to the function.
        Each coroutine then looks like `func(*arg)`, `arg` being an element from `args_iter`.
        Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
        """
        await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
        await self._map(func, args_iter, arg_stars=1, group_size=group_size,
                        end_callback=end_callback, cancel_callback=cancel_callback)

    async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
    async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
                            end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Like `map()` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as
        keyword-arguments to the function.
        Each coroutine then looks like `func(**arg)`, `arg` being an element from `kwargs_iter`.
        Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
        """
        await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
        await self._map(func, kwargs_iter, arg_stars=2, group_size=group_size,
                        end_callback=end_callback, cancel_callback=cancel_callback)

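To make the renamed `group_size` parameter concrete, a hedged sketch of the three mapping methods (illustrative only, not part of the diff; `fetch` stands in for any coroutine function and `pool` for a `TaskPool` instance set up elsewhere):

```python
async def run_batches(pool, fetch, urls):
    # At most 5 tasks spawned by this call run concurrently; map() returns once the first
    # 5 have been started, the rest are fed in from the internal arguments queue.
    await pool.map(fetch, urls, group_size=5)
    # starmap/doublestarmap unpack each element as positional/keyword arguments instead:
    await pool.starmap(fetch, [(url, 30) for url in urls], group_size=5)
    await pool.doublestarmap(fetch, [{'url': url} for url in urls], group_size=5)
```
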
@@ -26,126 +26,126 @@ from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Tuple, Union, Optional
from typing import Optional, Union

from . import constants
from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession
from .types import ConnectedCallbackT


log = logging.getLogger(__name__)


def tasks_str(num: int) -> str:
    return "tasks" if num != 1 else "task"


def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
    cmd = msg.strip().split(' ', 1)
    if len(cmd) > 1:
        try:
            return cmd[0], int(cmd[1])
        except ValueError:
            return None, None
    return cmd[0], None


class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
    client_class = ControlClient
    """
    Abstract base class for a task pool control server.

    This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
    connecting to it. The entire interface is defined within that session class.
    """
    _client_class = ControlClient

    @classmethod
    @property
    def client_class_name(cls) -> str:
        """Returns the name of the control client class matching the server class."""
        return cls._client_class.__name__

    @abstractmethod
    async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
    async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
        """
        Initializes, starts, and returns an async server instance (Unix or TCP type).

        Args:
            client_connected_cb:
                The callback for when a client connects to the server (as per `asyncio.start_server` or
                `asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
            **kwargs (optional):
                Keyword arguments to pass into the function that starts the server.

        Returns:
            The running server object (a type of `asyncio.Server`).
        """
        raise NotImplementedError

    @abstractmethod
    def final_callback(self) -> None:
    def _final_callback(self) -> None:
        """The method to run after the server's `serve_forever` methods ends for whatever reason."""
        raise NotImplementedError

    def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
        self._pool: SimpleTaskPool = pool
    def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
        """
        Initializes by merely saving the internal attributes, but without starting the server yet.
        The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
        tied to one specific task pool.

        Args:
            pool:
                An instance of a `BaseTaskPool` subclass to tie the server to.
            **server_kwargs (optional):
                Keyword arguments that will be passed into the function that starts the server.
        """
        self._pool: Union[TaskPool, SimpleTaskPool] = pool
        self._server_kwargs = server_kwargs
        self._server: Optional[AbstractServer] = None

    async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
        if num is None:
            num = 1
        log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
        writer.write(str(await self._pool.start(num)).encode())
    @property
    def pool(self) -> Union[TaskPool, SimpleTaskPool]:
        """Read-only property for accessing the task pool instance controlled by the server."""
        return self._pool

    def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
        if num is None:
            num = 1
        log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
        # the requested number may be greater than the total number of running tasks
        writer.write(str(self._pool.stop(num)).encode())

    def _stop_all_tasks(self, writer: StreamWriter) -> None:
        log.debug("%s requests stopping all tasks", self.client_class.__name__)
        writer.write(str(self._pool.stop_all()).encode())

    def _pool_size(self, writer: StreamWriter) -> None:
        log.debug("%s requests number of running tasks", self.client_class.__name__)
        writer.write(str(self._pool.num_running).encode())

    def _pool_func(self, writer: StreamWriter) -> None:
        log.debug("%s requests pool function", self.client_class.__name__)
        writer.write(self._pool.func_name.encode())

    async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
        while self._server.is_serving():
            msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
            if not msg:
                log.debug("%s disconnected", self.client_class.__name__)
                break
            cmd, arg = get_cmd_arg(msg)
            if cmd == constants.CMD_START:
                await self._start_tasks(writer, arg)
            elif cmd == constants.CMD_STOP:
                self._stop_tasks(writer, arg)
            elif cmd == constants.CMD_STOP_ALL:
                self._stop_all_tasks(writer)
            elif cmd == constants.CMD_NUM_RUNNING:
                self._pool_size(writer)
            elif cmd == constants.CMD_FUNC:
                self._pool_func(writer)
            else:
                log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
                writer.write(b"Invalid command!")
            await writer.drain()
    def is_serving(self) -> bool:
        """Wrapper around the `asyncio.Server.is_serving` method."""
        return self._server.is_serving()

    async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
        log.debug("%s connected", self.client_class.__name__)
        writer.write(str(self._pool).encode())
        await writer.drain()
        await self._listen(reader, writer)
        """
        The universal client callback that will be passed into the `_get_server_instance` method.
        Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
        """
        session = ControlSession(self, reader, writer)
        await session.client_handshake()
        await session.listen()

    async def _serve_forever(self) -> None:
        """
        To be run as an `asyncio.Task` by the following method.
        Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
        method is called, when the former method ends for whatever reason.
        """
        try:
            async with self._server:
                await self._server.serve_forever()
        except CancelledError:
            log.debug("%s stopped", self.__class__.__name__)
        finally:
            self.final_callback()
            self._final_callback()

    async def serve_forever(self) -> Task:
        """
        This method actually starts the server and begins listening to client connections on the specified interface.
        It should never block because the serving will be performed in a separate task.
        """
        log.debug("Starting %s...", self.__class__.__name__)
        self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
        self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
        return create_task(self._serve_forever())


class UnixControlServer(ControlServer):
    client_class = UnixControlClient
    """Task pool control server class that exposes a unix socket for control clients to connect to."""
    _client_class = UnixControlClient

    def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
        self._socket_path = Path(server_kwargs.pop('path'))
        super().__init__(pool, **server_kwargs)

    async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
        srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
    async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
        server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
        log.debug("Opened socket '%s'", str(self._socket_path))
        return srv
        return server

    def final_callback(self) -> None:
    def _final_callback(self) -> None:
        """Removes the unix socket on which the server was listening."""
        self._socket_path.unlink()
        log.debug("Removed socket '%s'", str(self._socket_path))

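A minimal launch sketch tying the pieces above together (illustrative, not part of the diff; the socket path is hypothetical, and the pool is assumed to exist already):

```python
import asyncio
from asyncio_taskpool.server import UnixControlServer

async def run_control_server(pool):
    # 'path' is popped out of the server kwargs and becomes the unix socket path.
    server = UnixControlServer(pool, path='/tmp/taskpool.sock')
    server_task = await server.serve_forever()  # returns immediately; serving runs in its own task
    ...
    # Cancelling the task ends serve_forever(), which triggers _final_callback()
    # and removes the socket again.
    server_task.cancel()
```
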
src/asyncio_taskpool/session.py (new file, 304 lines)
@@ -0,0 +1,304 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0

This file is part of asyncio-taskpool.

asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.

asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""

__doc__ = """
This module contains the the definition of the control session class used by the control server.
"""


import logging
import json
from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING

from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM

if TYPE_CHECKING:
    from .server import ControlServer


log = logging.getLogger(__name__)


class ControlSession:
    """
    This class defines the API for controlling a task pool instance from the outside.

    The commands received from a connected client are translated into method calls on the task pool instance.
    A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
    """

    def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
        """
        Instantiation should happen once a client connection to the control server has already been established.

        For more convenient/efficient access, some of the server's properties are saved in separate attributes.
        The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
        initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.

        Args:
            server:
                The instance of a `ControlServer` subclass starting the session.
            reader:
                The `asyncio.StreamReader` created when a client connected to the server.
            writer:
                The `asyncio.StreamWriter` created when a client connected to the server.
        """
        self._control_server: 'ControlServer' = server
        self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
        self._client_class_name = server.client_class_name
        self._reader: StreamReader = reader
        self._writer: StreamWriter = writer
        self._parser: Optional[CommandParser] = None
        self._subparsers = None

    def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
                     **kwargs) -> CommandParser:
        """
        Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.

        Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.

        Args:
            name:
                The command name; passed directly into the `add_parser` method.
            prog (optional):
                Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
                equal to the `name` argument.
            short_help (optional):
                Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
                `long_help` argument is present; in that case the `long_help` argument is passed as `help`.
            long_help (optional):
                Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
                the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
            **kwargs (optional):
                Any keyword-arguments to directly pass into the `add_parser` method.

        Returns:
            An instance of the `CommandParser` class representing the newly added control command.
        """
        if prog is None:
            prog = name
        kwargs.setdefault('help', short_help or long_help)
        kwargs.setdefault('description', long_help or short_help)
        return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)

    def _add_base_commands(self) -> None:
        """
        Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.

        These include commands mapping to the following pool methods:
        - __str__
        - pool_size (get/set property)
        - num_running
        """
        self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
        self._add_command(
            CMD.POOL_SIZE,
            short_help="Get/set the maximum number of tasks in the pool.",
            formatter_class=HelpFormatter
        ).add_optional_num_argument(
            default=None,
            help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
                 f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
        )
        self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))

    def _add_simple_commands(self) -> None:
        """
        Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.

        These include commands mapping to the following pool methods:
        - start
        - stop
        - stop_all
        - func_name
        """
        self._add_command(
            CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
        ).add_optional_num_argument(
            help="Number of tasks to start."
        )
        self._add_command(
            CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
        ).add_optional_num_argument(
            help="Number of tasks to stop."
        )
        self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
        self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))

    def _add_advanced_commands(self) -> None:
        """
        Adds the commands that are only supported, if a `TaskPool` object is controlled.

        These include commands mapping to the following pool methods:
        - ...
        """
        raise NotImplementedError

    def _init_parser(self, client_terminal_width: int) -> None:
        """
        Initializes and fully configures the `CommandParser` responsible for handling the input.

        Depending on what specific task pool class is controlled by the server, different commands are added.

        Args:
            client_terminal_width:
                The number of columns of the client's terminal to be able to nicely format messages from the parser.
        """
        parser_kwargs = {
            'prog': '',
            SESSION_WRITER: self._writer,
            CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
        }
        self._parser = CommandParser(**parser_kwargs)
        self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
        self._add_base_commands()
        if isinstance(self._pool, TaskPool):
            self._add_advanced_commands()
        elif isinstance(self._pool, SimpleTaskPool):
            self._add_simple_commands()
        elif isinstance(self._pool, BaseTaskPool):
            raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}")
        else:
            raise NotATaskPool(f"Not a task pool instance: {self._pool}")

    async def client_handshake(self) -> None:
        """
        This method must be invoked before starting any other client interaction.

        Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
        """
        client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
        log.debug("%s connected", self._client_class_name)
        self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
        self._writer.write(str(self._pool).encode())
        await self._writer.drain()

    async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
        """
        Acts as a wrapper around a call to a specific task pool method.

        The method is called and any exception is caught and saved. If there is no output and no exception caught, a
        generic confirmation message is sent back to the client. Otherwise the output or a string representation of
        the exception caught is sent back.

        Args:
            func:
                Reference to the task pool method.
            *args (optional):
                Any positional arguments to call the method with.
            *+kwargs (optional):
                Any keyword-arguments to call the method with.
        """
        output = await return_or_exception(func, *args, **kwargs)
        self._writer.write(b"ok" if output is None else str(output).encode())

    async def _cmd_name(self, **_kwargs) -> None:
        """Maps to the `__str__` method of any task pool class."""
        log.debug("%s requests task pool name", self._client_class_name)
        await self._write_function_output(self._pool.__class__.__str__, self._pool)

    async def _cmd_pool_size(self, **kwargs) -> None:
        """Maps to the `pool_size` property of any task pool class."""
        num = kwargs.get(NUM)
        if num is None:
            log.debug("%s requests pool size", self._client_class_name)
            await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
        else:
            log.debug("%s requests setting pool size to %s", self._client_class_name, num)
            await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)

    async def _cmd_num_running(self, **_kwargs) -> None:
        """Maps to the `num_running` property of any task pool class."""
        log.debug("%s requests number of running tasks", self._client_class_name)
        await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)

    async def _cmd_start(self, **kwargs) -> None:
        """Maps to the `start` method of the `SimpleTaskPool` class."""
        num = kwargs[NUM]
        log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
        await self._write_function_output(self._pool.start, num)

    async def _cmd_stop(self, **kwargs) -> None:
        """Maps to the `stop` method of the `SimpleTaskPool` class."""
        num = kwargs[NUM]
        log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
        await self._write_function_output(self._pool.stop, num)

    async def _cmd_stop_all(self, **_kwargs) -> None:
        """Maps to the `stop_all` method of the `SimpleTaskPool` class."""
        log.debug("%s requests stopping all tasks", self._client_class_name)
        await self._write_function_output(self._pool.stop_all)

    async def _cmd_func_name(self, **_kwargs) -> None:
        """Maps to the `func_name` method of the `SimpleTaskPool` class."""
        log.debug("%s requests pool function name", self._client_class_name)
        await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)

    async def _execute_command(self, **kwargs) -> None:
        """
        Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.

        Args:
            **kwargs:
                Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
                simply passed into the method determined from the command name.
        """
        method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
        await method(**kwargs)

    async def _parse_command(self, msg: str) -> None:
        """
        Takes a message from the client and attempts to parse it.

        If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
        `CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
        dictionary of keyword-arguments returned by the `CommandParser` passed into it.

        Args:
            msg:
                The non-empty string read from the client stream.
        """
        try:
            kwargs = vars(self._parser.parse_args(msg.split(' ')))
        except ArgumentError as e:
            self._writer.write(str(e).encode())
            return
        except HelpRequested:
            return
        await self._execute_command(**kwargs)

    async def listen(self) -> None:
        """
        Enters the main control loop that only ends if either the server or the client disconnect.

        Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
        This method should be called, when the client connection was established and the handshake was successful.
        It will obviously block indefinitely.
        """
        while self._control_server.is_serving():
            msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
            if not msg:
                log.debug("%s disconnected", self._client_class_name)
                break
            await self._parse_command(msg)
            await self._writer.drain()

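To make the dispatch in `_execute_command` concrete, a tiny illustration of the name mangling it performs (derived from the code above, no additional functionality is implied):

```python
# A client message like "pool-size 10" is parsed by the CommandParser into
# {'command': 'pool-size', 'num': 10}; the command name is then mapped to a handler:
cmd = 'pool-size'
handler_name = f'_cmd_{cmd.replace("-", "_")}'
print(handler_name)  # -> '_cmd_pool_size', i.e. ControlSession._cmd_pool_size(num=10) is called
```
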
src/asyncio_taskpool/session_parser.py (new file, 127 lines)
@@ -0,0 +1,127 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0

This file is part of asyncio-taskpool.

asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.

asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""

__doc__ = """
This module contains the the definition of the `CommandParser` class used in a control server session.
"""


from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter
from typing import Type, TypeVar

from .constants import SESSION_WRITER, CLIENT_INFO
from .exceptions import HelpRequested


FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
FORMATTER_CLASS = 'formatter_class'
NUM = 'num'


class CommandParser(ArgumentParser):
    """
    Subclass of the standard `argparse.ArgumentParser` for remote interaction.

    Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
    instance passed to it during initialization.
    Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
    connected client.
    Finally, it offers some convenience methods and makes use of custom exceptions.
    """

    @staticmethod
    def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
        """
        Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.

        Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
        as convenient, when making use of sub-parsers.

        Args:
            terminal_width:
                The number of columns of the terminal to which to adjust help formatting.
            base_cls (optional):
                The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.

        Returns:
            The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
        """
        if base_cls is None:
            base_cls = ArgumentDefaultsHelpFormatter

        class ClientHelpFormatter(base_cls):
            def __init__(self, *args, **kwargs) -> None:
                kwargs['width'] = terminal_width
                super().__init__(*args, **kwargs)
        return ClientHelpFormatter

    def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None:
        """
        Sets additional internal attributes depending on whether a parent-parser was defined.

        The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
        By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).

        Args:
            parent (optional):
                An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method
                of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present,
                the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`.
            **kwargs(optional):
                In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
                the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
                is empty.
        """
        self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER)
        self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
        kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
        kwargs.setdefault('exit_on_error', False)
        super().__init__(**kwargs)

    @property
    def session_writer(self) -> StreamWriter:
        """Returns the predefined stream writer object of the control session."""
        return self._session_writer

    @property
    def terminal_width(self) -> int:
        """Returns the predefined terminal width."""
        return self._terminal_width

    def _print_message(self, message: str, *args, **kwargs) -> None:
        """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
        if message:
            self._session_writer.write(message.encode())

    def exit(self, status: int = 0, message: str = None) -> None:
        """This is overridden to prevent system exit to be invoked."""
        if message:
            self._print_message(message)

    def print_help(self, file=None) -> None:
        """This just adds the custom `HelpRequested` exception after the parent class' method."""
        super().print_help(file)
        raise HelpRequested

    def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
        """Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified."""
        if not name_or_flags:
            name_or_flags = (NUM, )
        kwargs.setdefault('nargs', '?')
        kwargs.setdefault('default', 1)
        kwargs.setdefault('type', int)
        return self.add_argument(*name_or_flags, **kwargs)

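A sketch of how such a parser gets wired up, mirroring `ControlSession._init_parser` above; the keyword names come from the constants module, while the writer and width values are placeholders:

```python
from unittest.mock import MagicMock
from asyncio_taskpool.session_parser import CommandParser

writer = MagicMock()  # stands in for the session's StreamWriter; only written to on errors/help
parser = CommandParser(prog='', session_writer=writer, terminal_width=120)
subparsers = parser.add_subparsers(title="Commands", dest='command')
start_cmd = subparsers.add_parser('start', prog='start', parent=parser)
start_cmd.add_optional_num_argument(help="Number of tasks to start.")
print(vars(parser.parse_args('start 5'.split(' '))))  # -> {'command': 'start', 'num': 5}
```
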
@@ -20,6 +20,7 @@ Custom type definitions used in various modules.


from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union


@@ -28,10 +29,13 @@ T = TypeVar('T')
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]

AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]

EndCallbackT = Callable
CancelCallbackT = Callable

ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

PathT = Union[Path, str]

tests/test_client.py (new file, 207 lines)
@@ -0,0 +1,207 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0

This file is part of asyncio-taskpool.

asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.

asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""

__doc__ = """
Unittests for the `asyncio_taskpool.client` module.
"""


import json
import shutil
import sys
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

from asyncio_taskpool import client
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES


FOO, BAR = 'foo', 'bar'


class ControlClientTestCase(IsolatedAsyncioTestCase):

    def setUp(self) -> None:
        self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set())
        self.print_patcher = patch.object(client, 'print')
        self.mock_abstract_methods = self.abstract_patcher.start()
        self.mock_print = self.print_patcher.start()
        self.kwargs = {FOO: 123, BAR: 456}
        self.client = client.ControlClient(**self.kwargs)

        self.mock_read = AsyncMock(return_value=FOO.encode())
        self.mock_write, self.mock_drain = MagicMock(), AsyncMock()
        self.mock_reader = MagicMock(read=self.mock_read)
        self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain)

    def tearDown(self) -> None:
        self.abstract_patcher.stop()
        self.print_patcher.stop()

    def test_client_info(self):
        self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
                         client.ControlClient.client_info())

    async def test_abstract(self):
        with self.assertRaises(NotImplementedError):
            await self.client._open_connection(**self.kwargs)

    def test_init(self):
        self.assertEqual(self.kwargs, self.client._conn_kwargs)
        self.assertFalse(self.client._connected)

    @patch.object(client.ControlClient, 'client_info')
    async def test__server_handshake(self, mock_client_info: MagicMock):
        mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
        self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
        self.assertTrue(self.client._connected)
        mock_client_info.assert_called_once_with()
        self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
        self.mock_drain.assert_awaited_once_with()
        self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
        self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())

    @patch.object(client, 'input')
    def test__get_command(self, mock_input: MagicMock):
        self.client._connected = True

        mock_input.return_value = ' ' + FOO.upper() + ' '
        mock_close = MagicMock()
        mock_writer = MagicMock(close=mock_close)
        output = self.client._get_command(mock_writer)
        self.assertEqual(FOO, output)
        mock_input.assert_called_once()
        mock_close.assert_not_called()
        self.assertTrue(self.client._connected)

        mock_input.reset_mock()
        mock_input.side_effect = KeyboardInterrupt
        self.assertIsNone(self.client._get_command(mock_writer))
        mock_input.assert_called_once()
        mock_close.assert_not_called()
        self.assertTrue(self.client._connected)

        mock_input.reset_mock()
        mock_input.side_effect = EOFError
        self.assertIsNone(self.client._get_command(mock_writer))
        mock_input.assert_called_once()
        mock_close.assert_called_once()
        self.assertFalse(self.client._connected)

    @patch.object(client.ControlClient, '_get_command')
    async def test__interact(self, mock__get_command: MagicMock):
        self.client._connected = True

        mock__get_command.return_value = None
        self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
        self.mock_write.assert_not_called()
        self.mock_drain.assert_not_awaited()
        self.mock_read.assert_not_awaited()
        self.mock_print.assert_not_called()
        self.assertTrue(self.client._connected)

        mock__get_command.return_value = cmd = FOO + BAR + ' 123'
        self.mock_drain.side_effect = err = ConnectionError()
        self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
        self.mock_write.assert_called_once_with(cmd.encode())
        self.mock_drain.assert_awaited_once_with()
|
||||
self.mock_read.assert_not_awaited()
|
||||
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||
self.assertFalse(self.client._connected)
|
||||
|
||||
self.client._connected = True
|
||||
self.mock_write.reset_mock()
|
||||
self.mock_drain.reset_mock(side_effect=True)
|
||||
self.mock_print.reset_mock()
|
||||
|
||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||
self.mock_write.assert_called_once_with(cmd.encode())
|
||||
self.mock_drain.assert_awaited_once_with()
|
||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||
self.mock_print.assert_called_once_with(FOO)
|
||||
self.assertTrue(self.client._connected)
|
||||
|
||||
@patch.object(client.ControlClient, '_interact')
|
||||
@patch.object(client.ControlClient, '_server_handshake')
|
||||
@patch.object(client.ControlClient, '_open_connection')
|
||||
async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock,
|
||||
mock__interact: AsyncMock):
|
||||
mock__open_connection.return_value = None, None
|
||||
self.assertIsNone(await self.client.start())
|
||||
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||
mock__server_handshake.assert_not_awaited()
|
||||
mock__interact.assert_not_awaited()
|
||||
self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr)
|
||||
|
||||
mock__open_connection.reset_mock()
|
||||
self.mock_print.reset_mock()
|
||||
|
||||
mock__open_connection.return_value = self.mock_reader, self.mock_writer
|
||||
self.assertIsNone(await self.client.start())
|
||||
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||
mock__interact.assert_not_awaited()
|
||||
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||
|
||||
mock__open_connection.reset_mock()
|
||||
mock__server_handshake.reset_mock()
|
||||
self.mock_print.reset_mock()
|
||||
|
||||
self.client._connected = True
|
||||
def disconnect(*_args, **_kwargs) -> None: self.client._connected = False
|
||||
mock__interact.side_effect = disconnect
|
||||
self.assertIsNone(await self.client.start())
|
||||
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||
mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||
|
||||
|
||||
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
|
||||
self.mock_base_init = self.base_init_patcher.start()
|
||||
self.path = '/tmp/asyncio_taskpool'
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.base_init_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(Path(self.path), self.client._socket_path)
|
||||
self.mock_base_init.assert_called_once_with(**self.kwargs)
|
||||
|
||||
@patch.object(client, 'print')
|
||||
@patch.object(client, 'open_unix_connection')
|
||||
async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock):
|
||||
mock_open_unix_connection.return_value = expected_output = 'something'
|
||||
kwargs = {'a': 1, 'b': 2}
|
||||
output = await self.client._open_connection(**kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||
mock_print.assert_not_called()
|
||||
|
||||
mock_open_unix_connection.reset_mock()
|
||||
|
||||
mock_open_unix_connection.side_effect = FileNotFoundError
|
||||
output1, output2 = await self.client._open_connection(**kwargs)
|
||||
self.assertIsNone(output1)
|
||||
self.assertIsNone(output2)
|
||||
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||
mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)
|
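Taken together, the `ControlClient` tests above describe a simple interaction loop: prompt for a command, send it over the stream, await and print the server's response, and drop the connection on `EOFError` or a broken connection (while a `KeyboardInterrupt` only re-prompts). A stripped-down sketch of such a loop under those assumptions; the function name and the fixed read size are illustrative, not the library's API:
```python
# Sketch of one interaction step, consistent with the behaviour the tests
# above assert; the 4096-byte read size is an assumption.
import sys
from asyncio.streams import StreamReader, StreamWriter

async def interact_once(reader: StreamReader, writer: StreamWriter) -> bool:
    """Sends one command and prints the response; returns False to disconnect."""
    try:
        cmd = input("> ").strip()      # blocking prompt, as in the tests above
    except EOFError:
        writer.close()                 # stdin closed -> end the session
        return False
    except KeyboardInterrupt:
        return True                    # ignore Ctrl+C, just prompt again
    if not cmd:
        return True
    try:
        writer.write(cmd.encode())
        await writer.drain()
    except ConnectionError as e:
        print(e, file=sys.stderr)
        return False
    print((await reader.read(4096)).decode())
    return True
```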
@ -86,3 +86,43 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
        mock_queue = MagicMock(join=mock_join)
        self.assertIsNone(await helpers.join_queue(mock_queue))
        mock_join.assert_awaited_once_with()

    def test_task_str(self):
        self.assertEqual("task", helpers.tasks_str(1))
        self.assertEqual("tasks", helpers.tasks_str(0))
        self.assertEqual("tasks", helpers.tasks_str(-1))
        self.assertEqual("tasks", helpers.tasks_str(2))
        self.assertEqual("tasks", helpers.tasks_str(-10))
        self.assertEqual("tasks", helpers.tasks_str(42))

    def test_get_first_doc_line(self):
        expected_output = 'foo bar baz'
        mock_obj = MagicMock(__doc__=f"""{expected_output}
        something else

        even more
        """)
        output = helpers.get_first_doc_line(mock_obj)
        self.assertEqual(expected_output, output)

    async def test_return_or_exception(self):
        expected_output = '420'
        mock_func = AsyncMock(return_value=expected_output)
        args = (1, 3, 5)
        kwargs = {'a': 1, 'b': 2, 'c': 'foo'}
        output = await helpers.return_or_exception(mock_func, *args, **kwargs)
        self.assertEqual(expected_output, output)
        mock_func.assert_awaited_once_with(*args, **kwargs)

        mock_func = MagicMock(return_value=expected_output)
        output = await helpers.return_or_exception(mock_func, *args, **kwargs)
        self.assertEqual(expected_output, output)
        mock_func.assert_called_once_with(*args, **kwargs)

        class TestException(Exception):
            pass
        test_exception = TestException()
        mock_func = MagicMock(side_effect=test_exception)
        output = await helpers.return_or_exception(mock_func, *args, **kwargs)
        self.assertEqual(test_exception, output)
        mock_func.assert_called_once_with(*args, **kwargs)
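The `test_return_or_exception` assertions above imply a helper that awaits coroutine functions, calls ordinary ones, and hands back a raised exception instead of letting it propagate. A minimal sketch consistent with those assertions, not necessarily the package's actual implementation:
```python
# Minimal sketch consistent with the assertions above; `asyncio.iscoroutinefunction`
# also recognises AsyncMock instances, which the tests rely on.
from asyncio import iscoroutinefunction
from typing import Any, Callable

async def return_or_exception(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Returns the function's result, or the exception it raised, never raising itself."""
    try:
        if iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return func(*args, **kwargs)
    except Exception as e:
        return e
```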
@ -558,19 +558,19 @@ class TaskPoolTestCase(CommonTestCase):
        mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)

        mock_func, stars = MagicMock(), 3
        args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
        args_iter, group_size = (FOO, BAR, 1, 2, 3), 2
        end_cb, cancel_cb = MagicMock(), MagicMock()

        self.task_pool._locked = False
        with self.assertRaises(exceptions.PoolIsLocked):
            await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
            await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb)
        mock__set_up_args_queue.assert_not_called()
        mock__queue_consumer.assert_not_awaited()
        mock_flag_set.assert_not_called()

        self.task_pool._locked = True
        self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
        mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
        self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb))
        mock__set_up_args_queue.assert_called_once_with(args_iter, group_size)
        mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
                                                             end_callback=end_cb, cancel_callback=cancel_cb)])
        mock_flag_set.assert_called_once_with()
@ -578,28 +578,28 @@ class TaskPoolTestCase(CommonTestCase):
    @patch.object(pool.TaskPool, '_map')
    async def test_map(self, mock__map: AsyncMock):
        mock_func = MagicMock()
        arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
        arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2
        end_cb, cancel_cb = MagicMock(), MagicMock()
        self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
        mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
        self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb))
        mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size,
                                           end_callback=end_cb, cancel_callback=cancel_cb)

    @patch.object(pool.TaskPool, '_map')
    async def test_starmap(self, mock__map: AsyncMock):
        mock_func = MagicMock()
        args_iter, num_tasks = ([FOO], [BAR]), 2
        args_iter, group_size = ([FOO], [BAR]), 2
        end_cb, cancel_cb = MagicMock(), MagicMock()
        self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
        mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
        self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb))
        mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size,
                                           end_callback=end_cb, cancel_callback=cancel_cb)

    @patch.object(pool.TaskPool, '_map')
    async def test_doublestarmap(self, mock__map: AsyncMock):
        mock_func = MagicMock()
        kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
        kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2
        end_cb, cancel_cb = MagicMock(), MagicMock()
        self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
        mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
        self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb))
        mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size,
                                           end_callback=end_cb, cancel_callback=cancel_cb)
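The hunks above document the rename of the `num_tasks` parameter to `group_size` and make the delegation explicit: `map`, `starmap`, and `doublestarmap` only differ in how each item of the argument iterable is unpacked before the internal `_map` is awaited. A simplified sketch of that pattern, with signatures trimmed and callbacks omitted; this is not the library's actual code:
```python
# Simplified sketch of the delegation asserted above: the public map-style
# methods differ only in the `arg_stars` value (0, 1 or 2) they forward.
from typing import Any, Awaitable, Callable, Iterable

class MapDelegationSketch:
    async def _map(self, func: Callable[..., Awaitable[Any]], arg_iter: Iterable[Any],
                   arg_stars: int, group_size: int) -> None:
        ...  # would start `group_size` consumers, unpacking items according to `arg_stars`

    async def map(self, func, arg_iter, group_size) -> None:
        await self._map(func, arg_iter, arg_stars=0, group_size=group_size)

    async def starmap(self, func, args_iter, group_size) -> None:
        await self._map(func, args_iter, arg_stars=1, group_size=group_size)

    async def doublestarmap(self, func, kwargs_iter, group_size) -> None:
        await self._map(func, kwargs_iter, arg_stars=2, group_size=group_size)
```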
164  tests/test_server.py  Normal file
@ -0,0 +1,164 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.server` module.
|
||||
"""
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import server
|
||||
from asyncio_taskpool.client import ControlClient, UnixControlClient
|
||||
|
||||
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = server.log.level
|
||||
server.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
server.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
|
||||
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||
self.mock_pool = MagicMock()
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.abstract_patcher.stop()
|
||||
|
||||
def test_client_class_name(self):
|
||||
self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name)
|
||||
|
||||
async def test_abstract(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
args = [AsyncMock()]
|
||||
await self.server._get_server_instance(*args)
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.server._final_callback()
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.mock_pool, self.server._pool)
|
||||
self.assertEqual(self.kwargs, self.server._server_kwargs)
|
||||
self.assertIsNone(self.server._server)
|
||||
|
||||
def test_pool(self):
|
||||
self.assertEqual(self.mock_pool, self.server.pool)
|
||||
|
||||
def test_is_serving(self):
|
||||
self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR))
|
||||
self.assertEqual(FOO + BAR, self.server.is_serving())
|
||||
|
||||
@patch.object(server, 'ControlSession')
|
||||
async def test__client_connected_cb(self, mock_client_session_cls: MagicMock):
|
||||
mock_client_handshake, mock_listen = AsyncMock(), AsyncMock()
|
||||
mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen)
|
||||
mock_reader, mock_writer = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer))
|
||||
mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer)
|
||||
mock_client_handshake.assert_awaited_once_with()
|
||||
mock_listen.assert_awaited_once_with()
|
||||
|
||||
@patch.object(server.ControlServer, '_final_callback')
|
||||
async def test__serve_forever(self, mock__final_callback: MagicMock):
|
||||
mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError)
|
||||
self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever)
|
||||
with self.assertLogs(server.log, logging.DEBUG):
|
||||
self.assertIsNone(await self.server._serve_forever())
|
||||
mock_aenter.assert_awaited_once_with()
|
||||
mock_serve_forever.assert_awaited_once_with()
|
||||
mock__final_callback.assert_called_once_with()
|
||||
|
||||
mock_aenter.reset_mock()
|
||||
mock_serve_forever.reset_mock(side_effect=True)
|
||||
mock__final_callback.reset_mock()
|
||||
|
||||
self.assertIsNone(await self.server._serve_forever())
|
||||
mock_aenter.assert_awaited_once_with()
|
||||
mock_serve_forever.assert_awaited_once_with()
|
||||
mock__final_callback.assert_called_once_with()
|
||||
|
||||
@patch.object(server, 'create_task')
|
||||
@patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock())
|
||||
@patch.object(server.ControlServer, '_get_server_instance')
|
||||
async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock,
|
||||
mock_create_task: MagicMock):
|
||||
mock__serve_forever.return_value = mock_awaitable = 'some_coroutine'
|
||||
mock_create_task.return_value = expected_output = 12345
|
||||
output = await self.server.serve_forever()
|
||||
self.assertEqual(expected_output, output)
|
||||
mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs)
|
||||
mock__serve_forever.assert_called_once_with()
|
||||
mock_create_task.assert_called_once_with(mock_awaitable)
|
||||
|
||||
|
||||
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = server.log.level
|
||||
server.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
server.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
|
||||
self.mock_base_init = self.base_init_patcher.start()
|
||||
self.mock_pool = MagicMock()
|
||||
self.path = '/tmp/asyncio_taskpool'
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.base_init_patcher.stop()
|
||||
|
||||
def test__client_class(self):
|
||||
self.assertEqual(UnixControlClient, self.server._client_class)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(Path(self.path), self.server._socket_path)
|
||||
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
|
||||
|
||||
@patch.object(server, 'start_unix_server')
|
||||
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
|
||||
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
|
||||
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
|
||||
args = [mock_callback]
|
||||
output = await self.server._get_server_instance(*args, **mock_kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs)
|
||||
|
||||
def test__final_callback(self):
|
||||
self.server._socket_path = MagicMock()
|
||||
self.assertIsNone(self.server._final_callback())
|
||||
self.server._socket_path.unlink.assert_called_once_with()
|
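The `UnixControlServerTestCase` above exercises exactly two hooks of the abstract server: `_get_server_instance`, which must produce the listening server (here via `asyncio`'s `start_unix_server` on the configured socket path), and `_final_callback`, which removes the socket file afterwards. A hedged sketch of a class providing just those two hooks, with the rest of the base-class plumbing omitted:
```python
# Sketch of the two hooks the Unix-variant tests above exercise; class and
# constructor are illustrative, not the library's actual API surface.
from asyncio.streams import start_unix_server
from pathlib import Path

class UnixSocketHooksSketch:
    def __init__(self, path: str) -> None:
        self._socket_path = Path(path)

    async def _get_server_instance(self, client_connected_cb, **kwargs):
        """Starts a server listening on the configured Unix socket."""
        return await start_unix_server(client_connected_cb, self._socket_path, **kwargs)

    def _final_callback(self) -> None:
        """Removes the socket file once the server is done serving."""
        self._socket_path.unlink()
```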
324  tests/test_session.py  Normal file
@ -0,0 +1,324 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.session` module.
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
from argparse import ArgumentError, Namespace
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||
|
||||
from asyncio_taskpool import session
|
||||
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER
|
||||
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
|
||||
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
|
||||
|
||||
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = session.log.level
|
||||
session.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
session.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
|
||||
self.mock_client_class_name = FOO + BAR
|
||||
self.mock_server = MagicMock(pool=self.mock_pool,
|
||||
client_class_name=self.mock_client_class_name)
|
||||
self.mock_reader = MagicMock()
|
||||
self.mock_writer = MagicMock()
|
||||
self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.mock_server, self.session._control_server)
|
||||
self.assertEqual(self.mock_pool, self.session._pool)
|
||||
self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
|
||||
self.assertEqual(self.mock_reader, self.session._reader)
|
||||
self.assertEqual(self.mock_writer, self.session._writer)
|
||||
self.assertIsNone(self.session._parser)
|
||||
self.assertIsNone(self.session._subparsers)
|
||||
|
||||
def test__add_command(self):
|
||||
expected_output = 123456
|
||||
mock_add_parser = MagicMock(return_value=expected_output)
|
||||
self.session._subparsers = MagicMock(add_parser=mock_add_parser)
|
||||
self.session._parser = MagicMock()
|
||||
name, prog, short_help, long_help = 'abc', None, 'short123', None
|
||||
kwargs = {'x': 1, 'y': 2}
|
||||
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
|
||||
parent=self.session._parser, **kwargs)
|
||||
|
||||
mock_add_parser.reset_mock()
|
||||
|
||||
prog, long_help = 'ffffff', 'so long, wow'
|
||||
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
|
||||
parent=self.session._parser, **kwargs)
|
||||
|
||||
mock_add_parser.reset_mock()
|
||||
|
||||
short_help = None
|
||||
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
|
||||
parent=self.session._parser, **kwargs)
|
||||
|
||||
@patch.object(session, 'get_first_doc_line')
|
||||
@patch.object(session.ControlSession, '_add_command')
|
||||
def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
|
||||
self.assertIsNone(self.session._add_base_commands())
|
||||
mock__add_command.assert_called()
|
||||
mock_get_first_doc_line.assert_called()
|
||||
|
||||
mock__add_command.reset_mock()
|
||||
mock_get_first_doc_line.reset_mock()
|
||||
|
||||
self.assertIsNone(self.session._add_simple_commands())
|
||||
mock__add_command.assert_called()
|
||||
mock_get_first_doc_line.assert_called()
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.session._add_advanced_commands()
|
||||
|
||||
@patch.object(session.ControlSession, '_add_simple_commands')
|
||||
@patch.object(session.ControlSession, '_add_advanced_commands')
|
||||
@patch.object(session.ControlSession, '_add_base_commands')
|
||||
@patch.object(session, 'CommandParser')
|
||||
def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
|
||||
mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
|
||||
mock_command_parser_cls.return_value = mock_parser = MagicMock()
|
||||
self.session._pool = TaskPool()
|
||||
width = 1234
|
||||
expected_parser_kwargs = {
|
||||
'prog': '',
|
||||
SESSION_WRITER: self.mock_writer,
|
||||
CLIENT_INFO.TERMINAL_WIDTH: width,
|
||||
}
|
||||
self.assertIsNone(self.session._init_parser(width))
|
||||
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||
mock__add_base_commands.assert_called_once_with()
|
||||
mock__add_advanced_commands.assert_called_once_with()
|
||||
mock__add_simple_commands.assert_not_called()
|
||||
|
||||
mock_command_parser_cls.reset_mock()
|
||||
mock_parser.add_subparsers.reset_mock()
|
||||
mock__add_base_commands.reset_mock()
|
||||
mock__add_advanced_commands.reset_mock()
|
||||
mock__add_simple_commands.reset_mock()
|
||||
|
||||
async def fake_coroutine(): pass
|
||||
|
||||
self.session._pool = SimpleTaskPool(fake_coroutine)
|
||||
self.assertIsNone(self.session._init_parser(width))
|
||||
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||
mock__add_base_commands.assert_called_once_with()
|
||||
mock__add_advanced_commands.assert_not_called()
|
||||
mock__add_simple_commands.assert_called_once_with()
|
||||
|
||||
mock_command_parser_cls.reset_mock()
|
||||
mock_parser.add_subparsers.reset_mock()
|
||||
mock__add_base_commands.reset_mock()
|
||||
mock__add_advanced_commands.reset_mock()
|
||||
mock__add_simple_commands.reset_mock()
|
||||
|
||||
class FakeTaskPool(BaseTaskPool):
|
||||
pass
|
||||
|
||||
self.session._pool = FakeTaskPool()
|
||||
with self.assertRaises(UnknownTaskPoolClass):
|
||||
self.session._init_parser(width)
|
||||
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||
mock__add_base_commands.assert_called_once_with()
|
||||
mock__add_advanced_commands.assert_not_called()
|
||||
mock__add_simple_commands.assert_not_called()
|
||||
|
||||
mock_command_parser_cls.reset_mock()
|
||||
mock_parser.add_subparsers.reset_mock()
|
||||
mock__add_base_commands.reset_mock()
|
||||
mock__add_advanced_commands.reset_mock()
|
||||
mock__add_simple_commands.reset_mock()
|
||||
|
||||
self.session._pool = MagicMock()
|
||||
with self.assertRaises(NotATaskPool):
|
||||
self.session._init_parser(width)
|
||||
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||
mock__add_base_commands.assert_called_once_with()
|
||||
mock__add_advanced_commands.assert_not_called()
|
||||
mock__add_simple_commands.assert_not_called()
|
||||
|
||||
@patch.object(session.ControlSession, '_init_parser')
|
||||
async def test_client_handshake(self, mock__init_parser: MagicMock):
|
||||
width = 5678
|
||||
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
||||
mock_read = AsyncMock(return_value=msg.encode())
|
||||
self.mock_reader.read = mock_read
|
||||
self.mock_writer.drain = AsyncMock()
|
||||
self.assertIsNone(await self.session.client_handshake())
|
||||
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||
mock__init_parser.assert_called_once_with(width)
|
||||
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
|
||||
self.mock_writer.drain.assert_awaited_once_with()
|
||||
|
||||
@patch.object(session, 'return_or_exception')
|
||||
async def test__write_function_output(self, mock_return_or_exception: MagicMock):
|
||||
self.mock_writer.write = MagicMock()
|
||||
mock_return_or_exception.return_value = None
|
||||
func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
|
||||
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
|
||||
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
|
||||
self.mock_writer.write.assert_called_once_with(b"ok")
|
||||
|
||||
mock_return_or_exception.reset_mock()
|
||||
self.mock_writer.write.reset_mock()
|
||||
|
||||
mock_return_or_exception.return_value = output = MagicMock()
|
||||
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
|
||||
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
|
||||
self.mock_writer.write.assert_called_once_with(str(output).encode())
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_name(self, mock__write_function_output: AsyncMock):
|
||||
self.assertIsNone(await self.session._cmd_name())
|
||||
mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
|
||||
num = 12345
|
||||
kwargs = {session.NUM: num, FOO: BAR}
|
||||
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
|
||||
mock__write_function_output.assert_awaited_once_with(
|
||||
self.mock_pool.__class__.pool_size.fset, self.session._pool, num
|
||||
)
|
||||
|
||||
mock__write_function_output.reset_mock()
|
||||
|
||||
kwargs.pop(session.NUM)
|
||||
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
|
||||
mock__write_function_output.assert_awaited_once_with(
|
||||
self.mock_pool.__class__.pool_size.fget, self.session._pool
|
||||
)
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
|
||||
self.assertIsNone(await self.session._cmd_num_running())
|
||||
mock__write_function_output.assert_awaited_once_with(
|
||||
self.mock_pool.__class__.num_running.fget, self.session._pool
|
||||
)
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_start(self, mock__write_function_output: AsyncMock):
|
||||
num = 12345
|
||||
kwargs = {session.NUM: num, FOO: BAR}
|
||||
self.assertIsNone(await self.session._cmd_start(**kwargs))
|
||||
mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
|
||||
num = 12345
|
||||
kwargs = {session.NUM: num, FOO: BAR}
|
||||
self.assertIsNone(await self.session._cmd_stop(**kwargs))
|
||||
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
|
||||
self.assertIsNone(await self.session._cmd_stop_all())
|
||||
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
|
||||
|
||||
@patch.object(session.ControlSession, '_write_function_output')
|
||||
async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
|
||||
self.assertIsNone(await self.session._cmd_func_name())
|
||||
mock__write_function_output.assert_awaited_once_with(
|
||||
self.mock_pool.__class__.func_name.fget, self.session._pool
|
||||
)
|
||||
|
||||
async def test__execute_command(self):
|
||||
mock_method = AsyncMock()
|
||||
cmd = 'this-is-a-test'
|
||||
setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
|
||||
kwargs = {FOO: BAR, 'hello': 'python'}
|
||||
self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
|
||||
mock_method.assert_awaited_once_with(**kwargs)
|
||||
|
||||
@patch.object(session.ControlSession, '_execute_command')
|
||||
async def test__parse_command(self, mock__execute_command: AsyncMock):
|
||||
msg = 'asdf asd as a'
|
||||
kwargs = {FOO: BAR, 'hello': 'python'}
|
||||
mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
|
||||
self.session._parser = MagicMock(parse_args=mock_parse_args)
|
||||
self.mock_writer.write = MagicMock()
|
||||
self.assertIsNone(await self.session._parse_command(msg))
|
||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||
self.mock_writer.write.assert_not_called()
|
||||
mock__execute_command.assert_awaited_once_with(**kwargs)
|
||||
|
||||
mock__execute_command.reset_mock()
|
||||
mock_parse_args.reset_mock()
|
||||
|
||||
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
|
||||
self.assertIsNone(await self.session._parse_command(msg))
|
||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||
self.mock_writer.write.assert_called_once_with(str(exc).encode())
|
||||
mock__execute_command.assert_not_awaited()
|
||||
|
||||
self.mock_writer.write.reset_mock()
|
||||
mock_parse_args.reset_mock()
|
||||
|
||||
mock_parse_args.side_effect = HelpRequested()
|
||||
self.assertIsNone(await self.session._parse_command(msg))
|
||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||
self.mock_writer.write.assert_not_called()
|
||||
mock__execute_command.assert_not_awaited()
|
||||
|
||||
@patch.object(session.ControlSession, '_parse_command')
|
||||
async def test_listen(self, mock__parse_command: AsyncMock):
|
||||
def make_reader_return_empty():
|
||||
self.mock_reader.read.return_value = b''
|
||||
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
||||
msg = "fascinating"
|
||||
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
|
||||
self.assertIsNone(await self.session.listen())
|
||||
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
|
||||
mock__parse_command.assert_awaited_once_with(msg)
|
||||
self.mock_writer.drain.assert_awaited_once_with()
|
||||
|
||||
self.mock_reader.read.reset_mock()
|
||||
mock__parse_command.reset_mock()
|
||||
self.mock_writer.drain.reset_mock()
|
||||
|
||||
self.mock_server.is_serving = MagicMock(return_value=False)
|
||||
self.assertIsNone(await self.session.listen())
|
||||
self.mock_reader.read.assert_not_awaited()
|
||||
mock__parse_command.assert_not_awaited()
|
||||
self.mock_writer.drain.assert_not_awaited()
|
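One detail worth highlighting from the session tests above is `test__execute_command`: the parsed command string selects the handler coroutine by name, with dashes mapped to underscores and a `_cmd_` prefix. The pattern in isolation; the parameter name `command` and the example handler are assumptions for illustration:
```python
# Isolated sketch of the dispatch that `test__execute_command` above pins down:
# the command string picks a `_cmd_*` coroutine attribute by name.
import asyncio

class CommandDispatchSketch:
    async def _cmd_pool_size(self, **kwargs) -> None:
        print("pool-size invoked with", kwargs)

    async def _execute_command(self, command: str, **kwargs) -> None:
        """Resolves 'some-command' to `self._cmd_some_command` and awaits it."""
        handler = getattr(self, '_cmd_' + command.replace('-', '_'))
        await handler(**kwargs)

if __name__ == '__main__':
    asyncio.run(CommandDispatchSketch()._execute_command('pool-size', num=5))
```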
134  tests/test_session_parser.py  Normal file
@ -0,0 +1,134 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.session_parser` module.
|
||||
"""
|
||||
|
||||
|
||||
from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import session_parser
|
||||
from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO
|
||||
from asyncio_taskpool.exceptions import HelpRequested
|
||||
|
||||
|
||||
FOO = 'foo'
|
||||
|
||||
|
||||
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory')
|
||||
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
||||
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
||||
self.session_writer, self.terminal_width = MagicMock(), 420
|
||||
self.kwargs = {
|
||||
SESSION_WRITER: self.session_writer,
|
||||
CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width,
|
||||
session_parser.FORMATTER_CLASS: FOO
|
||||
}
|
||||
self.parser = session_parser.CommandParser(**self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.help_formatter_factory_patcher.stop()
|
||||
|
||||
def test_help_formatter_factory(self):
|
||||
self.help_formatter_factory_patcher.stop()
|
||||
|
||||
class MockBaseClass(HelpFormatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
terminal_width = 123456789
|
||||
cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass)
|
||||
self.assertTrue(issubclass(cls, MockBaseClass))
|
||||
instance = cls('prog')
|
||||
self.assertEqual(terminal_width, getattr(instance, '_width'))
|
||||
|
||||
cls = session_parser.CommandParser.help_formatter_factory(terminal_width)
|
||||
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
|
||||
instance = cls('prog')
|
||||
self.assertEqual(terminal_width, getattr(instance, '_width'))
|
||||
|
||||
def test_init(self):
|
||||
self.assertIsInstance(self.parser, ArgumentParser)
|
||||
self.assertEqual(self.session_writer, self.parser._session_writer)
|
||||
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
||||
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
||||
self.assertFalse(getattr(self.parser, 'exit_on_error'))
|
||||
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
||||
|
||||
def test_session_writer(self):
|
||||
self.assertEqual(self.session_writer, self.parser.session_writer)
|
||||
|
||||
def test_terminal_width(self):
|
||||
self.assertEqual(self.terminal_width, self.parser.terminal_width)
|
||||
|
||||
def test__print_message(self):
|
||||
self.session_writer.write = MagicMock()
|
||||
self.assertIsNone(self.parser._print_message(''))
|
||||
self.session_writer.write.assert_not_called()
|
||||
msg = 'foo bar baz'
|
||||
self.assertIsNone(self.parser._print_message(msg))
|
||||
self.session_writer.write.assert_called_once_with(msg.encode())
|
||||
|
||||
@patch.object(session_parser.CommandParser, '_print_message')
|
||||
def test_exit(self, mock__print_message: MagicMock):
|
||||
self.assertIsNone(self.parser.exit(123, ''))
|
||||
mock__print_message.assert_not_called()
|
||||
msg = 'foo bar baz'
|
||||
self.assertIsNone(self.parser.exit(123, msg))
|
||||
mock__print_message.assert_called_once_with(msg)
|
||||
|
||||
@patch.object(session_parser.ArgumentParser, 'print_help')
|
||||
def test_print_help(self, mock_print_help: MagicMock):
|
||||
arg = MagicMock()
|
||||
with self.assertRaises(HelpRequested):
|
||||
self.parser.print_help(arg)
|
||||
mock_print_help.assert_called_once_with(arg)
|
||||
|
||||
def test_add_optional_num_argument(self):
|
||||
metavar = 'FOOBAR'
|
||||
action = self.parser.add_optional_num_argument(metavar=metavar)
|
||||
self.assertIsInstance(action, Action)
|
||||
self.assertEqual('?', action.nargs)
|
||||
self.assertEqual(1, action.default)
|
||||
self.assertEqual(int, action.type)
|
||||
self.assertEqual(metavar, action.metavar)
|
||||
num = 111
|
||||
kwargs = vars(self.parser.parse_args([f'{num}']))
|
||||
self.assertDictEqual({session_parser.NUM: num}, kwargs)
|
||||
|
||||
name = f'--{FOO}'
|
||||
nargs = '+'
|
||||
default = 1
|
||||
_type = float
|
||||
required = True
|
||||
dest = 'foo_bar'
|
||||
action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type,
|
||||
required=required, metavar=metavar, dest=dest)
|
||||
self.assertIsInstance(action, Action)
|
||||
self.assertEqual(nargs, action.nargs)
|
||||
self.assertEqual(default, action.default)
|
||||
self.assertEqual(_type, action.type)
|
||||
self.assertEqual(required, action.required)
|
||||
self.assertEqual(metavar, action.metavar)
|
||||
self.assertEqual(dest, action.dest)
|
||||
kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5']))
|
||||
self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)
|
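The `test_help_formatter_factory` case above fixes the factory's contract: it returns a subclass of the given `HelpFormatter` base (defaulting to `ArgumentDefaultsHelpFormatter`) whose instances format help output to the client's terminal width. A sketch satisfying exactly those assertions, which may differ from the package's own implementation:
```python
# Width-fixing formatter factory consistent with the assertions above.
from argparse import ArgumentDefaultsHelpFormatter, HelpFormatter
from typing import Type

def help_formatter_factory(terminal_width: int,
                           base_cls: Type[HelpFormatter] = ArgumentDefaultsHelpFormatter
                           ) -> Type[HelpFormatter]:
    class ClientHelpFormatter(base_cls):
        def __init__(self, *args, **kwargs) -> None:
            kwargs['width'] = terminal_width   # force help output to the client's terminal width
            super().__init__(*args, **kwargs)
    return ClientHelpFormatter
```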
@ -130,7 +130,7 @@ async def main() -> None:
    # **only** once there is room in the pool **and** no more than one of these last four tasks is running.
    args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
    print("Calling `starmap`...")
    await pool.starmap(other_work, args_list, num_tasks=2)
    await pool.starmap(other_work, args_list, group_size=2)
    print("`starmap` returned")
    # Now we lock the pool, so that we can safely await all our tasks.
    pool.lock()