Compare commits

...

3 Commits

12 changed files with 579 additions and 78 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.3.3 version = 0.3.5
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -25,54 +25,116 @@ import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path from pathlib import Path
from typing import Optional
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT from .types import ClientConnT, PathT
class ControlClient(ABC): class ControlClient(ABC):
"""
Abstract base class for a simple implementation of a task pool control client.
@abstractmethod Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
async def open_connection(self, **kwargs) -> ClientConnT: TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
raise NotImplementedError This is a minimal working implementation.
"""
@staticmethod @staticmethod
def client_info() -> dict: def client_info() -> dict:
"""Returns a dictionary of client information relevant for the handshake with the server."""
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns} return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
@abstractmethod
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
as keyword-arguments.
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
`None` and `None`, if it failed to establish the defined connection.
"""
raise NotImplementedError
def __init__(self, **conn_kwargs) -> None: def __init__(self, **conn_kwargs) -> None:
"""Simply stores the connection keyword-arguments necessary for opening the connection."""
self._conn_kwargs = conn_kwargs self._conn_kwargs = conn_kwargs
self._connected: bool = False self._connected: bool = False
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None: async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
Performs the first interaction with the server providing it with the necessary client information.
Upon completion, the server's info is printed.
Args:
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
self._connected = True self._connected = True
writer.write(json.dumps(self.client_info()).encode()) writer.write(json.dumps(self.client_info()).encode())
await writer.drain() await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: def _get_command(self, writer: StreamWriter) -> Optional[str]:
"""
Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.
Args:
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
"""
try: try:
msg = input("> ").strip().lower() msg = input("> ").strip().lower()
except EOFError: except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
msg = CLIENT_EXIT msg = CLIENT_EXIT
except KeyboardInterrupt: except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print() print()
return return
if msg == CLIENT_EXIT: if msg == CLIENT_EXIT:
writer.close() writer.close()
self._connected = False self._connected = False
return return
return msg
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
Reacts to the user's command, potentially performing a back-and-forth interaction with the server.
If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.
Args:
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
cmd = self._get_command(writer)
if cmd is None:
return
try: try:
writer.write(msg.encode()) # Send the command to the server.
writer.write(cmd.encode())
await writer.drain() await writer.drain()
except ConnectionError as e: except ConnectionError as e:
self._connected = False self._connected = False
print(e, file=sys.stderr) print(e, file=sys.stderr)
return return
# Await the server's response, then print it.
print((await reader.read(SESSION_MSG_BYTES)).decode()) print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self): async def start(self) -> None:
reader, writer = await self.open_connection(**self._conn_kwargs) """
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.
If the connection can not be established, an error message is printed to `stderr` and the method returns.
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
disconnected-message.
"""
reader, writer = await self._open_connection(**self._conn_kwargs)
if reader is None: if reader is None:
print("Failed to connect.", file=sys.stderr) print("Failed to connect.", file=sys.stderr)
return return
@ -83,11 +145,24 @@ class ControlClient(ABC):
class UnixControlClient(ControlClient): class UnixControlClient(ControlClient):
def __init__(self, **conn_kwargs) -> None: """Task pool control client that expects a unix socket to be exposed by the control server."""
self._socket_path = Path(conn_kwargs.pop('path'))
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
"""
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
"""
self._socket_path = Path(socket_path)
super().__init__(**conn_kwargs) super().__init__(**conn_kwargs)
async def open_connection(self, **kwargs) -> ClientConnT: async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Wrapper around the `asyncio.open_unix_connection` function.
Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
otherwise, the stream-reader and -writer tuple is returned.
"""
try: try:
return await open_unix_connection(self._socket_path, **kwargs) return await open_unix_connection(self._socket_path, **kwargs)
except FileNotFoundError: except FileNotFoundError:

View File

@ -24,7 +24,7 @@ PACKAGE_NAME = 'asyncio_taskpool'
CLIENT_EXIT = 'exit' CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100
SESSION_PARSER_WRITER = 'session_writer' SESSION_WRITER = 'session_writer'
class CLIENT_INFO: class CLIENT_INFO:

View File

@ -319,8 +319,8 @@ class BaseTaskPool:
""" """
Cancels all tasks still running within the pool. Cancels all tasks still running within the pool.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks. Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
than the number of arguments from `args_iter`. than the number of arguments from `args_iter`.
In this case, those already running will be cancelled, while the following will **never even start**. In this case, those already running will be cancelled, while the following will **never even start**.
@ -355,8 +355,8 @@ class BaseTaskPool:
The `lock()` method must have been called prior to this. The `lock()` method must have been called prior to this.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks. Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
than the number of arguments from `args_iter`. than the number of arguments from `args_iter`.
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
blocking this method. Otherwise it will wait until they all have started. blocking this method. Otherwise it will wait until they all have started.
@ -552,10 +552,10 @@ class TaskPool(BaseTaskPool):
def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue: def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
""" """
Helper function for `_map()`. Helper function for `_map()`.
Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to a new `asyncio.Queue`. Takes the iterable of function arguments `args_iter` and adds up to `group_size` to a new `asyncio.Queue`.
The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned. The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.
If the iterable contains less than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer` If the iterable contains less than `group_size` elements, nothing else happens; otherwise the `_queue_producer`
is started as a separate task with the arguments queue and and iterator of the remaining arguments. is started as a separate task with the arguments queue and and iterator of the remaining arguments.
Args: Args:
@ -567,7 +567,7 @@ class TaskPool(BaseTaskPool):
Returns: Returns:
The newly created and filled arguments queue for spawning new tasks. The newly created and filled arguments queue for spawning new tasks.
""" """
# Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run # Setting the `maxsize` of the queue to `group_size` will ensure that no more than `group_size` tasks will run
# concurrently because the size of the queue is what will determine the number of immediately started tasks in # concurrently because the size of the queue is what will determine the number of immediately started tasks in
# the `_map()` method and each of those will only ever start (at most) one other task upon ending. # the `_map()` method and each of those will only ever start (at most) one other task upon ending.
args_queue = Queue(maxsize=num_tasks) args_queue = Queue(maxsize=num_tasks)
@ -575,12 +575,12 @@ class TaskPool(BaseTaskPool):
args_iter = iter(args_iter) args_iter = iter(args_iter)
try: try:
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
# tasks, which will be at most `num_tasks` (meaning the queue will be full). # tasks, which will be at most `group_size` (meaning the queue will be full).
for i in range(num_tasks): for i in range(num_tasks):
args_queue.put_nowait(next(args_iter)) args_queue.put_nowait(next(args_iter))
except StopIteration: except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the # If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue. # specified `group_size`. Still, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue. # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
pass pass
else: else:
@ -591,17 +591,27 @@ class TaskPool(BaseTaskPool):
create_task(self._queue_producer(args_queue, args_iter)) create_task(self._queue_producer(args_queue, args_iter))
return args_queue return args_queue
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, group_size: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
""" """
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches. Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
TODO: If task groups are implemented, consider adding all tasks from one call of this method to the same group
and referring to "group size" rather than chunk/batch size.
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable.
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `args_iter`.
It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable. The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
any given moment in time. Assuming the number of elements produced by `args_iter` is greater than `group_size`,
this method will block **only** until the first `group_size` tasks have been **started**, before returning.
(If the number of elements from `args_iter` is smaller than `group_size`, this method will return as soon as
all of them have been started.)
As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
limit, this ensures that the number of tasks running concurrently as a result of this method call is always
equal to `group_size` (except for when `args_iter` is exhausted of course).
Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
This method sets up an internal arguments queue which is continuously filled while consuming the `args_iter`.
Args: Args:
func: func:
@ -610,8 +620,8 @@ class TaskPool(BaseTaskPool):
The iterable of arguments; each element is to be passed into a `func` call when spawning a new task. The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
arg_stars (optional): arg_stars (optional):
Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2. Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
num_tasks (optional): group_size (optional):
The maximum number of the new tasks to run concurrently. The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
end_callback (optional): end_callback (optional):
A callback to execute after a task has ended. A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
@ -624,7 +634,7 @@ class TaskPool(BaseTaskPool):
""" """
if not self._locked: if not self._locked:
raise exceptions.PoolIsLocked("Cannot start new tasks") raise exceptions.PoolIsLocked("Cannot start new tasks")
args_queue = self._set_up_args_queue(args_iter, num_tasks) args_queue = self._set_up_args_queue(args_iter, group_size)
# We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the # We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
# `_queue_callback` triggered by one or more of them. # `_queue_callback` triggered by one or more of them.
# This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more # This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more
@ -638,29 +648,34 @@ class TaskPool(BaseTaskPool):
# Now the callbacks can immediately trigger more tasks. # Now the callbacks can immediately trigger more tasks.
first_batch_started.set() first_batch_started.set()
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1, async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
""" """
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method. An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches. Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `arg` being an element from the iterable. Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
Once the first batch of tasks has started to run, this method returns. The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
As soon as on of them finishes, it triggers the start of a new task (assuming there is room in the pool) any given moment in time. Assuming the number of elements produced by `arg_iter` is greater than `group_size`,
consuming the next element from the arguments iterable. this method will block **only** until the first `group_size` tasks have been **started**, before returning.
If the size of the pool never imposes a limit, this ensures that there is almost continuously the desired number (If the number of elements from `arg_iter` is smaller than `group_size`, this method will return as soon as
of tasks from this call concurrently running within the pool. all of them have been started.)
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
limit, this ensures that the number of tasks running concurrently as a result of this method call is always
equal to `group_size` (except for when `arg_iter` is exhausted of course).
Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
Args: Args:
func: func:
The coroutine function to use for spawning the new tasks within the task pool. The coroutine function to use for spawning the new tasks within the task pool.
arg_iter: arg_iter:
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task. The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
num_tasks (optional): group_size (optional):
The maximum number of the new tasks to run concurrently. The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
end_callback (optional): end_callback (optional):
A callback to execute after a task has ended. A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
@ -672,27 +687,27 @@ class TaskPool(BaseTaskPool):
`PoolIsLocked` if the pool has been locked. `PoolIsLocked` if the pool has been locked.
`NotCoroutine` if `func` is not a coroutine function. `NotCoroutine` if `func` is not a coroutine function.
""" """
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks, await self._map(func, arg_iter, arg_stars=0, group_size=group_size,
end_callback=end_callback, cancel_callback=cancel_callback) end_callback=end_callback, cancel_callback=cancel_callback)
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1, async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
""" """
Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as
positional arguments to the function. positional arguments to the function.
Each coroutine then looks like `func(*arg)`, `arg` being an element from `args_iter`. Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
""" """
await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks, await self._map(func, args_iter, arg_stars=1, group_size=group_size,
end_callback=end_callback, cancel_callback=cancel_callback) end_callback=end_callback, cancel_callback=cancel_callback)
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1, async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
""" """
Like `map()` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as Like `map()` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as
keyword-arguments to the function. keyword-arguments to the function.
Each coroutine then looks like `func(**arg)`, `arg` being an element from `kwargs_iter`. Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
""" """
await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, await self._map(func, kwargs_iter, arg_stars=2, group_size=group_size,
end_callback=end_callback, cancel_callback=cancel_callback) end_callback=end_callback, cancel_callback=cancel_callback)

View File

@ -25,7 +25,7 @@ from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from .helpers import get_first_doc_line, return_or_exception, tasks_str from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
@ -166,7 +166,7 @@ class ControlSession:
""" """
parser_kwargs = { parser_kwargs = {
'prog': '', 'prog': '',
SESSION_PARSER_WRITER: self._writer, SESSION_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
} }
self._parser = CommandParser(**parser_kwargs) self._parser = CommandParser(**parser_kwargs)

View File

@ -1,8 +1,29 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the `CommandParser` class used in a control server session.
"""
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
from asyncio.streams import StreamWriter from asyncio.streams import StreamWriter
from typing import Type, TypeVar from typing import Type, TypeVar
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO from .constants import SESSION_WRITER, CLIENT_INFO
from .exceptions import HelpRequested from .exceptions import HelpRequested
@ -12,8 +33,33 @@ NUM = 'num'
class CommandParser(ArgumentParser): class CommandParser(ArgumentParser):
"""
Subclass of the standard `argparse.ArgumentParser` for remote interaction.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client.
Finally, it offers some convenience methods and makes use of custom exceptions.
"""
@staticmethod @staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
"""
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
as convenient, when making use of sub-parsers.
Args:
terminal_width:
The number of columns of the terminal to which to adjust help formatting.
base_cls (optional):
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.
Returns:
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
"""
if base_cls is None: if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter base_cls = ArgumentDefaultsHelpFormatter
@ -23,35 +69,56 @@ class CommandParser(ArgumentParser):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
return ClientHelpFormatter return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None: def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None) """
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER) Sets additional internal attributes depending on whether a parent-parser was defined.
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
Args:
parent (optional):
An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method
of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present,
the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`.
**kwargs(optional):
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
is empty.
"""
self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER)
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH) self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs.setdefault('exit_on_error', False) kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs) super().__init__(**kwargs)
@property @property
def stream_writer(self) -> StreamWriter: def session_writer(self) -> StreamWriter:
return self._stream_writer """Returns the predefined stream writer object of the control session."""
return self._session_writer
@property @property
def terminal_width(self) -> int: def terminal_width(self) -> int:
"""Returns the predefined terminal width."""
return self._terminal_width return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None: def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
if message: if message:
self.stream_writer.write(message.encode()) self._session_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None: def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked."""
if message: if message:
self._print_message(message) self._print_message(message)
def print_help(self, file=None) -> None: def print_help(self, file=None) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().print_help(file) super().print_help(file)
raise HelpRequested raise HelpRequested
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action: def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
"""Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified."""
if not name_or_flags: if not name_or_flags:
name_or_flags = (NUM, ) name_or_flags = (NUM, )
kwargs.setdefault('nargs', '?') kwargs.setdefault('nargs', '?')

View File

@ -20,6 +20,7 @@ Custom type definitions used in various modules.
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
@ -36,3 +37,5 @@ CancelCallbackT = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]] ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
PathT = Union[Path, str]

207
tests/test_client.py Normal file
View File

@ -0,0 +1,207 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.client` module.
"""
import json
import shutil
import sys
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import client
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
FOO, BAR = 'foo', 'bar'
class ControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set())
self.print_patcher = patch.object(client, 'print')
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_print = self.print_patcher.start()
self.kwargs = {FOO: 123, BAR: 456}
self.client = client.ControlClient(**self.kwargs)
self.mock_read = AsyncMock(return_value=FOO.encode())
self.mock_write, self.mock_drain = MagicMock(), AsyncMock()
self.mock_reader = MagicMock(read=self.mock_read)
self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain)
def tearDown(self) -> None:
self.abstract_patcher.stop()
self.print_patcher.stop()
def test_client_info(self):
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
client.ControlClient.client_info())
async def test_abstract(self):
with self.assertRaises(NotImplementedError):
await self.client._open_connection(**self.kwargs)
def test_init(self):
self.assertEqual(self.kwargs, self.client._conn_kwargs)
self.assertFalse(self.client._connected)
@patch.object(client.ControlClient, 'client_info')
async def test__server_handshake(self, mock_client_info: MagicMock):
mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
self.assertTrue(self.client._connected)
mock_client_info.assert_called_once_with()
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())
@patch.object(client, 'input')
def test__get_command(self, mock_input: MagicMock):
self.client._connected = True
mock_input.return_value = ' ' + FOO.upper() + ' '
mock_close = MagicMock()
mock_writer = MagicMock(close=mock_close)
output = self.client._get_command(mock_writer)
self.assertEqual(FOO, output)
mock_input.assert_called_once()
mock_close.assert_not_called()
self.assertTrue(self.client._connected)
mock_input.reset_mock()
mock_input.side_effect = KeyboardInterrupt
self.assertIsNone(self.client._get_command(mock_writer))
mock_input.assert_called_once()
mock_close.assert_not_called()
self.assertTrue(self.client._connected)
mock_input.reset_mock()
mock_input.side_effect = EOFError
self.assertIsNone(self.client._get_command(mock_writer))
mock_input.assert_called_once()
mock_close.assert_called_once()
self.assertFalse(self.client._connected)
@patch.object(client.ControlClient, '_get_command')
async def test__interact(self, mock__get_command: MagicMock):
self.client._connected = True
mock__get_command.return_value = None
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_not_called()
self.mock_drain.assert_not_awaited()
self.mock_read.assert_not_awaited()
self.mock_print.assert_not_called()
self.assertTrue(self.client._connected)
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
self.mock_drain.side_effect = err = ConnectionError()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_not_awaited()
self.mock_print.assert_called_once_with(err, file=sys.stderr)
self.assertFalse(self.client._connected)
self.client._connected = True
self.mock_write.reset_mock()
self.mock_drain.reset_mock(side_effect=True)
self.mock_print.reset_mock()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with(FOO)
self.assertTrue(self.client._connected)
@patch.object(client.ControlClient, '_interact')
@patch.object(client.ControlClient, '_server_handshake')
@patch.object(client.ControlClient, '_open_connection')
async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock,
mock__interact: AsyncMock):
mock__open_connection.return_value = None, None
self.assertIsNone(await self.client.start())
mock__open_connection.assert_awaited_once_with(**self.kwargs)
mock__server_handshake.assert_not_awaited()
mock__interact.assert_not_awaited()
self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr)
mock__open_connection.reset_mock()
self.mock_print.reset_mock()
mock__open_connection.return_value = self.mock_reader, self.mock_writer
self.assertIsNone(await self.client.start())
mock__open_connection.assert_awaited_once_with(**self.kwargs)
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
mock__interact.assert_not_awaited()
self.mock_print.assert_called_once_with("Disconnected from control server.")
mock__open_connection.reset_mock()
mock__server_handshake.reset_mock()
self.mock_print.reset_mock()
self.client._connected = True
def disconnect(*_args, **_kwargs) -> None: self.client._connected = False
mock__interact.side_effect = disconnect
self.assertIsNone(await self.client.start())
mock__open_connection.assert_awaited_once_with(**self.kwargs)
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer)
self.mock_print.assert_called_once_with("Disconnected from control server.")
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.path = '/tmp/asyncio_taskpool'
self.kwargs = {FOO: 123, BAR: 456}
self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test_init(self):
self.assertEqual(Path(self.path), self.client._socket_path)
self.mock_base_init.assert_called_once_with(**self.kwargs)
@patch.object(client, 'print')
@patch.object(client, 'open_unix_connection')
async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock):
mock_open_unix_connection.return_value = expected_output = 'something'
kwargs = {'a': 1, 'b': 2}
output = await self.client._open_connection(**kwargs)
self.assertEqual(expected_output, output)
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
mock_print.assert_not_called()
mock_open_unix_connection.reset_mock()
mock_open_unix_connection.side_effect = FileNotFoundError
output1, output2 = await self.client._open_connection(**kwargs)
self.assertIsNone(output1)
self.assertIsNone(output2)
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)

View File

@ -558,19 +558,19 @@ class TaskPoolTestCase(CommonTestCase):
mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set) mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
mock_func, stars = MagicMock(), 3 mock_func, stars = MagicMock(), 3
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2 args_iter, group_size = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._locked = False self.task_pool._locked = False
with self.assertRaises(exceptions.PoolIsLocked): with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb) await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb)
mock__set_up_args_queue.assert_not_called() mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited() mock__queue_consumer.assert_not_awaited()
mock_flag_set.assert_not_called() mock_flag_set.assert_not_called()
self.task_pool._locked = True self.task_pool._locked = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks) mock__set_up_args_queue.assert_called_once_with(args_iter, group_size)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars, mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)]) end_callback=end_cb, cancel_callback=cancel_cb)])
mock_flag_set.assert_called_once_with() mock_flag_set.assert_called_once_with()
@ -578,28 +578,28 @@ class TaskPoolTestCase(CommonTestCase):
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock): async def test_map(self, mock__map: AsyncMock):
mock_func = MagicMock() mock_func = MagicMock()
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2 arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks, mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock): async def test_starmap(self, mock__map: AsyncMock):
mock_func = MagicMock() mock_func = MagicMock()
args_iter, num_tasks = ([FOO], [BAR]), 2 args_iter, group_size = ([FOO], [BAR]), 2
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks, mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock): async def test_doublestarmap(self, mock__map: AsyncMock):
mock_func = MagicMock() mock_func = MagicMock()
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2 kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)

View File

@ -25,7 +25,7 @@ from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool import session from asyncio_taskpool import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
@ -119,7 +119,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
width = 1234 width = 1234
expected_parser_kwargs = { expected_parser_kwargs = {
'prog': '', 'prog': '',
SESSION_PARSER_WRITER: self.mock_writer, SESSION_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width, CLIENT_INFO.TERMINAL_WIDTH: width,
} }
self.assertIsNone(self.session._init_parser(width)) self.assertIsNone(self.session._init_parser(width))

View File

@ -0,0 +1,134 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.session_parser` module.
"""
from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch
from asyncio_taskpool import session_parser
from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO
from asyncio_taskpool.exceptions import HelpRequested
FOO = 'foo'
class ControlServerTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.session_writer, self.terminal_width = MagicMock(), 420
self.kwargs = {
SESSION_WRITER: self.session_writer,
CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width,
session_parser.FORMATTER_CLASS: FOO
}
self.parser = session_parser.CommandParser(**self.kwargs)
def tearDown(self) -> None:
self.help_formatter_factory_patcher.stop()
def test_help_formatter_factory(self):
self.help_formatter_factory_patcher.stop()
class MockBaseClass(HelpFormatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
terminal_width = 123456789
cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass)
self.assertTrue(issubclass(cls, MockBaseClass))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
cls = session_parser.CommandParser.help_formatter_factory(terminal_width)
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.session_writer, self.parser._session_writer)
self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
def test_session_writer(self):
self.assertEqual(self.session_writer, self.parser.session_writer)
def test_terminal_width(self):
self.assertEqual(self.terminal_width, self.parser.terminal_width)
def test__print_message(self):
self.session_writer.write = MagicMock()
self.assertIsNone(self.parser._print_message(''))
self.session_writer.write.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg))
self.session_writer.write.assert_called_once_with(msg.encode())
@patch.object(session_parser.CommandParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock):
self.assertIsNone(self.parser.exit(123, ''))
mock__print_message.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser.exit(123, msg))
mock__print_message.assert_called_once_with(msg)
@patch.object(session_parser.ArgumentParser, 'print_help')
def test_print_help(self, mock_print_help: MagicMock):
arg = MagicMock()
with self.assertRaises(HelpRequested):
self.parser.print_help(arg)
mock_print_help.assert_called_once_with(arg)
def test_add_optional_num_argument(self):
metavar = 'FOOBAR'
action = self.parser.add_optional_num_argument(metavar=metavar)
self.assertIsInstance(action, Action)
self.assertEqual('?', action.nargs)
self.assertEqual(1, action.default)
self.assertEqual(int, action.type)
self.assertEqual(metavar, action.metavar)
num = 111
kwargs = vars(self.parser.parse_args([f'{num}']))
self.assertDictEqual({session_parser.NUM: num}, kwargs)
name = f'--{FOO}'
nargs = '+'
default = 1
_type = float
required = True
dest = 'foo_bar'
action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type,
required=required, metavar=metavar, dest=dest)
self.assertIsInstance(action, Action)
self.assertEqual(nargs, action.nargs)
self.assertEqual(default, action.default)
self.assertEqual(_type, action.type)
self.assertEqual(required, action.required)
self.assertEqual(metavar, action.metavar)
self.assertEqual(dest, action.dest)
kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5']))
self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)

View File

@ -130,7 +130,7 @@ async def main() -> None:
# **only** once there is room in the pool **and** no more than one of these last four tasks is running. # **only** once there is room in the pool **and** no more than one of these last four tasks is running.
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)] args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
print("Calling `starmap`...") print("Calling `starmap`...")
await pool.starmap(other_work, args_list, num_tasks=2) await pool.starmap(other_work, args_list, group_size=2)
print("`starmap` returned") print("`starmap` returned")
# Now we lock the pool, so that we can safely await all our tasks. # Now we lock the pool, so that we can safely await all our tasks.
pool.lock() pool.lock()