generated from daniil-berg/boilerplate-py
Compare commits
3 Commits
96d01e7259
...
4994135062
Author | SHA1 | Date | |
---|---|---|---|
4994135062 | |||
d0c0177681 | |||
bac7b32342 |
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = asyncio-taskpool
|
||||
version = 0.3.3
|
||||
version = 0.3.5
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Dynamically manage pools of asyncio tasks
|
||||
|
@ -25,54 +25,116 @@ import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
|
||||
from .types import ClientConnT
|
||||
from .types import ClientConnT, PathT
|
||||
|
||||
|
||||
class ControlClient(ABC):
|
||||
"""
|
||||
Abstract base class for a simple implementation of a task pool control client.
|
||||
|
||||
@abstractmethod
|
||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||
raise NotImplementedError
|
||||
Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
|
||||
TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
|
||||
This is a minimal working implementation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def client_info() -> dict:
|
||||
"""Returns a dictionary of client information relevant for the handshake with the server."""
|
||||
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
|
||||
|
||||
@abstractmethod
|
||||
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||
"""
|
||||
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
|
||||
|
||||
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
|
||||
as keyword-arguments.
|
||||
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
|
||||
`None` and `None`, if it failed to establish the defined connection.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, **conn_kwargs) -> None:
|
||||
"""Simply stores the connection keyword-arguments necessary for opening the connection."""
|
||||
self._conn_kwargs = conn_kwargs
|
||||
self._connected: bool = False
|
||||
|
||||
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||
"""
|
||||
Performs the first interaction with the server providing it with the necessary client information.
|
||||
|
||||
Upon completion, the server's info is printed.
|
||||
|
||||
Args:
|
||||
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
|
||||
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||
"""
|
||||
self._connected = True
|
||||
writer.write(json.dumps(self.client_info()).encode())
|
||||
await writer.drain()
|
||||
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||
|
||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||
def _get_command(self, writer: StreamWriter) -> Optional[str]:
|
||||
"""
|
||||
Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.
|
||||
|
||||
Args:
|
||||
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||
|
||||
Returns:
|
||||
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
|
||||
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
|
||||
"""
|
||||
try:
|
||||
msg = input("> ").strip().lower()
|
||||
except EOFError:
|
||||
except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
|
||||
msg = CLIENT_EXIT
|
||||
except KeyboardInterrupt:
|
||||
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
||||
print()
|
||||
return
|
||||
if msg == CLIENT_EXIT:
|
||||
writer.close()
|
||||
self._connected = False
|
||||
return
|
||||
return msg
|
||||
|
||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||
"""
|
||||
Reacts to the user's command, potentially performing a back-and-forth interaction with the server.
|
||||
|
||||
If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
|
||||
If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.
|
||||
|
||||
Args:
|
||||
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
|
||||
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||
"""
|
||||
cmd = self._get_command(writer)
|
||||
if cmd is None:
|
||||
return
|
||||
try:
|
||||
writer.write(msg.encode())
|
||||
# Send the command to the server.
|
||||
writer.write(cmd.encode())
|
||||
await writer.drain()
|
||||
except ConnectionError as e:
|
||||
self._connected = False
|
||||
print(e, file=sys.stderr)
|
||||
return
|
||||
# Await the server's response, then print it.
|
||||
print((await reader.read(SESSION_MSG_BYTES)).decode())
|
||||
|
||||
async def start(self):
|
||||
reader, writer = await self.open_connection(**self._conn_kwargs)
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.
|
||||
|
||||
If the connection can not be established, an error message is printed to `stderr` and the method returns.
|
||||
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
|
||||
disconnected-message.
|
||||
"""
|
||||
reader, writer = await self._open_connection(**self._conn_kwargs)
|
||||
if reader is None:
|
||||
print("Failed to connect.", file=sys.stderr)
|
||||
return
|
||||
@ -83,11 +145,24 @@ class ControlClient(ABC):
|
||||
|
||||
|
||||
class UnixControlClient(ControlClient):
|
||||
def __init__(self, **conn_kwargs) -> None:
|
||||
self._socket_path = Path(conn_kwargs.pop('path'))
|
||||
"""Task pool control client that expects a unix socket to be exposed by the control server."""
|
||||
|
||||
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
|
||||
"""
|
||||
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
|
||||
|
||||
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
|
||||
"""
|
||||
self._socket_path = Path(socket_path)
|
||||
super().__init__(**conn_kwargs)
|
||||
|
||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||
"""
|
||||
Wrapper around the `asyncio.open_unix_connection` function.
|
||||
|
||||
Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
|
||||
otherwise, the stream-reader and -writer tuple is returned.
|
||||
"""
|
||||
try:
|
||||
return await open_unix_connection(self._socket_path, **kwargs)
|
||||
except FileNotFoundError:
|
||||
|
@ -24,7 +24,7 @@ PACKAGE_NAME = 'asyncio_taskpool'
|
||||
CLIENT_EXIT = 'exit'
|
||||
|
||||
SESSION_MSG_BYTES = 1024 * 100
|
||||
SESSION_PARSER_WRITER = 'session_writer'
|
||||
SESSION_WRITER = 'session_writer'
|
||||
|
||||
|
||||
class CLIENT_INFO:
|
||||
|
@ -319,8 +319,8 @@ class BaseTaskPool:
|
||||
"""
|
||||
Cancels all tasks still running within the pool.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, those already running will be cancelled, while the following will **never even start**.
|
||||
|
||||
@ -355,8 +355,8 @@ class BaseTaskPool:
|
||||
|
||||
The `lock()` method must have been called prior to this.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
|
||||
blocking this method. Otherwise it will wait until they all have started.
|
||||
@ -552,10 +552,10 @@ class TaskPool(BaseTaskPool):
|
||||
def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
|
||||
"""
|
||||
Helper function for `_map()`.
|
||||
Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to a new `asyncio.Queue`.
|
||||
Takes the iterable of function arguments `args_iter` and adds up to `group_size` to a new `asyncio.Queue`.
|
||||
The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.
|
||||
|
||||
If the iterable contains less than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer`
|
||||
If the iterable contains less than `group_size` elements, nothing else happens; otherwise the `_queue_producer`
|
||||
is started as a separate task with the arguments queue and and iterator of the remaining arguments.
|
||||
|
||||
Args:
|
||||
@ -567,7 +567,7 @@ class TaskPool(BaseTaskPool):
|
||||
Returns:
|
||||
The newly created and filled arguments queue for spawning new tasks.
|
||||
"""
|
||||
# Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run
|
||||
# Setting the `maxsize` of the queue to `group_size` will ensure that no more than `group_size` tasks will run
|
||||
# concurrently because the size of the queue is what will determine the number of immediately started tasks in
|
||||
# the `_map()` method and each of those will only ever start (at most) one other task upon ending.
|
||||
args_queue = Queue(maxsize=num_tasks)
|
||||
@ -575,12 +575,12 @@ class TaskPool(BaseTaskPool):
|
||||
args_iter = iter(args_iter)
|
||||
try:
|
||||
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
|
||||
# tasks, which will be at most `num_tasks` (meaning the queue will be full).
|
||||
# tasks, which will be at most `group_size` (meaning the queue will be full).
|
||||
for i in range(num_tasks):
|
||||
args_queue.put_nowait(next(args_iter))
|
||||
except StopIteration:
|
||||
# If we get here, this means that the number of elements in the arguments iterator was less than the
|
||||
# specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue.
|
||||
# specified `group_size`. Still, the number of tasks to start immediately will be the size of the queue.
|
||||
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
|
||||
pass
|
||||
else:
|
||||
@ -591,17 +591,27 @@ class TaskPool(BaseTaskPool):
|
||||
create_task(self._queue_producer(args_queue, args_iter))
|
||||
return args_queue
|
||||
|
||||
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
|
||||
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, group_size: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
"""
|
||||
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
|
||||
TODO: If task groups are implemented, consider adding all tasks from one call of this method to the same group
|
||||
and referring to "group size" rather than chunk/batch size.
|
||||
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable.
|
||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||
|
||||
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
|
||||
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `args_iter`.
|
||||
|
||||
It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
|
||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
||||
any given moment in time. Assuming the number of elements produced by `args_iter` is greater than `group_size`,
|
||||
this method will block **only** until the first `group_size` tasks have been **started**, before returning.
|
||||
(If the number of elements from `args_iter` is smaller than `group_size`, this method will return as soon as
|
||||
all of them have been started.)
|
||||
|
||||
As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
|
||||
the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
|
||||
limit, this ensures that the number of tasks running concurrently as a result of this method call is always
|
||||
equal to `group_size` (except for when `args_iter` is exhausted of course).
|
||||
|
||||
Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
|
||||
|
||||
This method sets up an internal arguments queue which is continuously filled while consuming the `args_iter`.
|
||||
|
||||
Args:
|
||||
func:
|
||||
@ -610,8 +620,8 @@ class TaskPool(BaseTaskPool):
|
||||
The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
|
||||
arg_stars (optional):
|
||||
Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
|
||||
num_tasks (optional):
|
||||
The maximum number of the new tasks to run concurrently.
|
||||
group_size (optional):
|
||||
The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
|
||||
end_callback (optional):
|
||||
A callback to execute after a task has ended.
|
||||
It is run with the task's ID as its only positional argument.
|
||||
@ -624,7 +634,7 @@ class TaskPool(BaseTaskPool):
|
||||
"""
|
||||
if not self._locked:
|
||||
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
||||
args_queue = self._set_up_args_queue(args_iter, num_tasks)
|
||||
args_queue = self._set_up_args_queue(args_iter, group_size)
|
||||
# We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
|
||||
# `_queue_callback` triggered by one or more of them.
|
||||
# This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more
|
||||
@ -638,29 +648,34 @@ class TaskPool(BaseTaskPool):
|
||||
# Now the callbacks can immediately trigger more tasks.
|
||||
first_batch_started.set()
|
||||
|
||||
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1,
|
||||
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
"""
|
||||
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||
|
||||
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
|
||||
Each coroutine looks like `func(arg)`, `arg` being an element from the iterable.
|
||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
|
||||
|
||||
Once the first batch of tasks has started to run, this method returns.
|
||||
As soon as on of them finishes, it triggers the start of a new task (assuming there is room in the pool)
|
||||
consuming the next element from the arguments iterable.
|
||||
If the size of the pool never imposes a limit, this ensures that there is almost continuously the desired number
|
||||
of tasks from this call concurrently running within the pool.
|
||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
||||
any given moment in time. Assuming the number of elements produced by `arg_iter` is greater than `group_size`,
|
||||
this method will block **only** until the first `group_size` tasks have been **started**, before returning.
|
||||
(If the number of elements from `arg_iter` is smaller than `group_size`, this method will return as soon as
|
||||
all of them have been started.)
|
||||
|
||||
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
|
||||
As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
|
||||
the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
|
||||
limit, this ensures that the number of tasks running concurrently as a result of this method call is always
|
||||
equal to `group_size` (except for when `arg_iter` is exhausted of course).
|
||||
|
||||
Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
|
||||
|
||||
Args:
|
||||
func:
|
||||
The coroutine function to use for spawning the new tasks within the task pool.
|
||||
arg_iter:
|
||||
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
|
||||
num_tasks (optional):
|
||||
The maximum number of the new tasks to run concurrently.
|
||||
group_size (optional):
|
||||
The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
|
||||
end_callback (optional):
|
||||
A callback to execute after a task has ended.
|
||||
It is run with the task's ID as its only positional argument.
|
||||
@ -672,27 +687,27 @@ class TaskPool(BaseTaskPool):
|
||||
`PoolIsLocked` if the pool has been locked.
|
||||
`NotCoroutine` if `func` is not a coroutine function.
|
||||
"""
|
||||
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,
|
||||
await self._map(func, arg_iter, arg_stars=0, group_size=group_size,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
|
||||
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
"""
|
||||
Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as
|
||||
positional arguments to the function.
|
||||
Each coroutine then looks like `func(*arg)`, `arg` being an element from `args_iter`.
|
||||
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
||||
"""
|
||||
await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
|
||||
await self._map(func, args_iter, arg_stars=1, group_size=group_size,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
|
||||
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
"""
|
||||
Like `map()` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as
|
||||
keyword-arguments to the function.
|
||||
Each coroutine then looks like `func(**arg)`, `arg` being an element from `kwargs_iter`.
|
||||
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
||||
"""
|
||||
await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
|
||||
await self._map(func, kwargs_iter, arg_stars=2, group_size=group_size,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ from argparse import ArgumentError, HelpFormatter
|
||||
from asyncio.streams import StreamReader, StreamWriter
|
||||
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||
|
||||
from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
|
||||
from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
|
||||
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
|
||||
from .helpers import get_first_doc_line, return_or_exception, tasks_str
|
||||
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
|
||||
@ -166,7 +166,7 @@ class ControlSession:
|
||||
"""
|
||||
parser_kwargs = {
|
||||
'prog': '',
|
||||
SESSION_PARSER_WRITER: self._writer,
|
||||
SESSION_WRITER: self._writer,
|
||||
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
|
||||
}
|
||||
self._parser = CommandParser(**parser_kwargs)
|
||||
|
@ -1,8 +1,29 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
This module contains the the definition of the `CommandParser` class used in a control server session.
|
||||
"""
|
||||
|
||||
|
||||
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
|
||||
from asyncio.streams import StreamWriter
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from .constants import SESSION_PARSER_WRITER, CLIENT_INFO
|
||||
from .constants import SESSION_WRITER, CLIENT_INFO
|
||||
from .exceptions import HelpRequested
|
||||
|
||||
|
||||
@ -12,8 +33,33 @@ NUM = 'num'
|
||||
|
||||
|
||||
class CommandParser(ArgumentParser):
|
||||
"""
|
||||
Subclass of the standard `argparse.ArgumentParser` for remote interaction.
|
||||
|
||||
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
|
||||
instance passed to it during initialization.
|
||||
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
|
||||
connected client.
|
||||
Finally, it offers some convenience methods and makes use of custom exceptions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
|
||||
"""
|
||||
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.
|
||||
|
||||
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
|
||||
as convenient, when making use of sub-parsers.
|
||||
|
||||
Args:
|
||||
terminal_width:
|
||||
The number of columns of the terminal to which to adjust help formatting.
|
||||
base_cls (optional):
|
||||
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.
|
||||
|
||||
Returns:
|
||||
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
|
||||
"""
|
||||
if base_cls is None:
|
||||
base_cls = ArgumentDefaultsHelpFormatter
|
||||
|
||||
@ -23,35 +69,56 @@ class CommandParser(ArgumentParser):
|
||||
super().__init__(*args, **kwargs)
|
||||
return ClientHelpFormatter
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
parent: CommandParser = kwargs.pop('parent', None)
|
||||
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop(SESSION_PARSER_WRITER)
|
||||
def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None:
|
||||
"""
|
||||
Sets additional internal attributes depending on whether a parent-parser was defined.
|
||||
|
||||
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
|
||||
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
|
||||
|
||||
Args:
|
||||
parent (optional):
|
||||
An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method
|
||||
of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present,
|
||||
the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`.
|
||||
**kwargs(optional):
|
||||
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
|
||||
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
|
||||
is empty.
|
||||
"""
|
||||
self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER)
|
||||
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
|
||||
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
|
||||
kwargs.setdefault('exit_on_error', False)
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def stream_writer(self) -> StreamWriter:
|
||||
return self._stream_writer
|
||||
def session_writer(self) -> StreamWriter:
|
||||
"""Returns the predefined stream writer object of the control session."""
|
||||
return self._session_writer
|
||||
|
||||
@property
|
||||
def terminal_width(self) -> int:
|
||||
"""Returns the predefined terminal width."""
|
||||
return self._terminal_width
|
||||
|
||||
def _print_message(self, message: str, *args, **kwargs) -> None:
|
||||
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
|
||||
if message:
|
||||
self.stream_writer.write(message.encode())
|
||||
self._session_writer.write(message.encode())
|
||||
|
||||
def exit(self, status: int = 0, message: str = None) -> None:
|
||||
"""This is overridden to prevent system exit to be invoked."""
|
||||
if message:
|
||||
self._print_message(message)
|
||||
|
||||
def print_help(self, file=None) -> None:
|
||||
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
|
||||
super().print_help(file)
|
||||
raise HelpRequested
|
||||
|
||||
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
|
||||
"""Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified."""
|
||||
if not name_or_flags:
|
||||
name_or_flags = (NUM, )
|
||||
kwargs.setdefault('nargs', '?')
|
||||
|
@ -20,6 +20,7 @@ Custom type definitions used in various modules.
|
||||
|
||||
|
||||
from asyncio.streams import StreamReader, StreamWriter
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
|
||||
|
||||
|
||||
@ -36,3 +37,5 @@ CancelCallbackT = Callable
|
||||
|
||||
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||
|
||||
PathT = Union[Path, str]
|
||||
|
207
tests/test_client.py
Normal file
207
tests/test_client.py
Normal file
@ -0,0 +1,207 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.client` module.
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import client
|
||||
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
|
||||
|
||||
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class ControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set())
|
||||
self.print_patcher = patch.object(client, 'print')
|
||||
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||
self.mock_print = self.print_patcher.start()
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.client = client.ControlClient(**self.kwargs)
|
||||
|
||||
self.mock_read = AsyncMock(return_value=FOO.encode())
|
||||
self.mock_write, self.mock_drain = MagicMock(), AsyncMock()
|
||||
self.mock_reader = MagicMock(read=self.mock_read)
|
||||
self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.abstract_patcher.stop()
|
||||
self.print_patcher.stop()
|
||||
|
||||
def test_client_info(self):
|
||||
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
|
||||
client.ControlClient.client_info())
|
||||
|
||||
async def test_abstract(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
await self.client._open_connection(**self.kwargs)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.kwargs, self.client._conn_kwargs)
|
||||
self.assertFalse(self.client._connected)
|
||||
|
||||
@patch.object(client.ControlClient, 'client_info')
|
||||
async def test__server_handshake(self, mock_client_info: MagicMock):
|
||||
mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
|
||||
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||
self.assertTrue(self.client._connected)
|
||||
mock_client_info.assert_called_once_with()
|
||||
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
||||
self.mock_drain.assert_awaited_once_with()
|
||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||
self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())
|
||||
|
||||
@patch.object(client, 'input')
|
||||
def test__get_command(self, mock_input: MagicMock):
|
||||
self.client._connected = True
|
||||
|
||||
mock_input.return_value = ' ' + FOO.upper() + ' '
|
||||
mock_close = MagicMock()
|
||||
mock_writer = MagicMock(close=mock_close)
|
||||
output = self.client._get_command(mock_writer)
|
||||
self.assertEqual(FOO, output)
|
||||
mock_input.assert_called_once()
|
||||
mock_close.assert_not_called()
|
||||
self.assertTrue(self.client._connected)
|
||||
|
||||
mock_input.reset_mock()
|
||||
mock_input.side_effect = KeyboardInterrupt
|
||||
self.assertIsNone(self.client._get_command(mock_writer))
|
||||
mock_input.assert_called_once()
|
||||
mock_close.assert_not_called()
|
||||
self.assertTrue(self.client._connected)
|
||||
|
||||
mock_input.reset_mock()
|
||||
mock_input.side_effect = EOFError
|
||||
self.assertIsNone(self.client._get_command(mock_writer))
|
||||
mock_input.assert_called_once()
|
||||
mock_close.assert_called_once()
|
||||
self.assertFalse(self.client._connected)
|
||||
|
||||
@patch.object(client.ControlClient, '_get_command')
|
||||
async def test__interact(self, mock__get_command: MagicMock):
|
||||
self.client._connected = True
|
||||
|
||||
mock__get_command.return_value = None
|
||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||
self.mock_write.assert_not_called()
|
||||
self.mock_drain.assert_not_awaited()
|
||||
self.mock_read.assert_not_awaited()
|
||||
self.mock_print.assert_not_called()
|
||||
self.assertTrue(self.client._connected)
|
||||
|
||||
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
||||
self.mock_drain.side_effect = err = ConnectionError()
|
||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||
self.mock_write.assert_called_once_with(cmd.encode())
|
||||
self.mock_drain.assert_awaited_once_with()
|
||||
self.mock_read.assert_not_awaited()
|
||||
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||
self.assertFalse(self.client._connected)
|
||||
|
||||
self.client._connected = True
|
||||
self.mock_write.reset_mock()
|
||||
self.mock_drain.reset_mock(side_effect=True)
|
||||
self.mock_print.reset_mock()
|
||||
|
||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||
self.mock_write.assert_called_once_with(cmd.encode())
|
||||
self.mock_drain.assert_awaited_once_with()
|
||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||
self.mock_print.assert_called_once_with(FOO)
|
||||
self.assertTrue(self.client._connected)
|
||||
|
||||
@patch.object(client.ControlClient, '_interact')
|
||||
@patch.object(client.ControlClient, '_server_handshake')
|
||||
@patch.object(client.ControlClient, '_open_connection')
|
||||
async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock,
|
||||
mock__interact: AsyncMock):
|
||||
mock__open_connection.return_value = None, None
|
||||
self.assertIsNone(await self.client.start())
|
||||
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||
mock__server_handshake.assert_not_awaited()
|
||||
mock__interact.assert_not_awaited()
|
||||
self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr)
|
||||
|
||||
mock__open_connection.reset_mock()
|
||||
self.mock_print.reset_mock()
|
||||
|
||||
mock__open_connection.return_value = self.mock_reader, self.mock_writer
|
||||
self.assertIsNone(await self.client.start())
|
||||
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||
mock__interact.assert_not_awaited()
|
||||
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||
|
||||
mock__open_connection.reset_mock()
|
||||
mock__server_handshake.reset_mock()
|
||||
self.mock_print.reset_mock()
|
||||
|
||||
self.client._connected = True
|
||||
def disconnect(*_args, **_kwargs) -> None: self.client._connected = False
|
||||
mock__interact.side_effect = disconnect
|
||||
self.assertIsNone(await self.client.start())
|
||||
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||
mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||
|
||||
|
||||
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
|
||||
self.mock_base_init = self.base_init_patcher.start()
|
||||
self.path = '/tmp/asyncio_taskpool'
|
||||
self.kwargs = {FOO: 123, BAR: 456}
|
||||
self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.base_init_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(Path(self.path), self.client._socket_path)
|
||||
self.mock_base_init.assert_called_once_with(**self.kwargs)
|
||||
|
||||
@patch.object(client, 'print')
|
||||
@patch.object(client, 'open_unix_connection')
|
||||
async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock):
|
||||
mock_open_unix_connection.return_value = expected_output = 'something'
|
||||
kwargs = {'a': 1, 'b': 2}
|
||||
output = await self.client._open_connection(**kwargs)
|
||||
self.assertEqual(expected_output, output)
|
||||
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||
mock_print.assert_not_called()
|
||||
|
||||
mock_open_unix_connection.reset_mock()
|
||||
|
||||
mock_open_unix_connection.side_effect = FileNotFoundError
|
||||
output1, output2 = await self.client._open_connection(**kwargs)
|
||||
self.assertIsNone(output1)
|
||||
self.assertIsNone(output2)
|
||||
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||
mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)
|
@ -558,19 +558,19 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
|
||||
|
||||
mock_func, stars = MagicMock(), 3
|
||||
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
|
||||
args_iter, group_size = (FOO, BAR, 1, 2, 3), 2
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
|
||||
self.task_pool._locked = False
|
||||
with self.assertRaises(exceptions.PoolIsLocked):
|
||||
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
|
||||
await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb)
|
||||
mock__set_up_args_queue.assert_not_called()
|
||||
mock__queue_consumer.assert_not_awaited()
|
||||
mock_flag_set.assert_not_called()
|
||||
|
||||
self.task_pool._locked = True
|
||||
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
|
||||
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
|
||||
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb))
|
||||
mock__set_up_args_queue.assert_called_once_with(args_iter, group_size)
|
||||
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)])
|
||||
mock_flag_set.assert_called_once_with()
|
||||
@ -578,28 +578,28 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
@patch.object(pool.TaskPool, '_map')
|
||||
async def test_map(self, mock__map: AsyncMock):
|
||||
mock_func = MagicMock()
|
||||
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
|
||||
arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
|
||||
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
|
||||
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb))
|
||||
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
|
||||
@patch.object(pool.TaskPool, '_map')
|
||||
async def test_starmap(self, mock__map: AsyncMock):
|
||||
mock_func = MagicMock()
|
||||
args_iter, num_tasks = ([FOO], [BAR]), 2
|
||||
args_iter, group_size = ([FOO], [BAR]), 2
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
|
||||
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
|
||||
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb))
|
||||
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
|
||||
@patch.object(pool.TaskPool, '_map')
|
||||
async def test_doublestarmap(self, mock__map: AsyncMock):
|
||||
mock_func = MagicMock()
|
||||
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
|
||||
kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
|
||||
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
|
||||
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb))
|
||||
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||
|
||||
from asyncio_taskpool import session
|
||||
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER
|
||||
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER
|
||||
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
|
||||
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
|
||||
|
||||
@ -119,7 +119,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
width = 1234
|
||||
expected_parser_kwargs = {
|
||||
'prog': '',
|
||||
SESSION_PARSER_WRITER: self.mock_writer,
|
||||
SESSION_WRITER: self.mock_writer,
|
||||
CLIENT_INFO.TERMINAL_WIDTH: width,
|
||||
}
|
||||
self.assertIsNone(self.session._init_parser(width))
|
||||
|
134
tests/test_session_parser.py
Normal file
134
tests/test_session_parser.py
Normal file
@ -0,0 +1,134 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.session_parser` module.
|
||||
"""
|
||||
|
||||
|
||||
from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import session_parser
|
||||
from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO
|
||||
from asyncio_taskpool.exceptions import HelpRequested
|
||||
|
||||
|
||||
FOO = 'foo'
|
||||
|
||||
|
||||
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory')
|
||||
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
||||
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
||||
self.session_writer, self.terminal_width = MagicMock(), 420
|
||||
self.kwargs = {
|
||||
SESSION_WRITER: self.session_writer,
|
||||
CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width,
|
||||
session_parser.FORMATTER_CLASS: FOO
|
||||
}
|
||||
self.parser = session_parser.CommandParser(**self.kwargs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.help_formatter_factory_patcher.stop()
|
||||
|
||||
def test_help_formatter_factory(self):
|
||||
self.help_formatter_factory_patcher.stop()
|
||||
|
||||
class MockBaseClass(HelpFormatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
terminal_width = 123456789
|
||||
cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass)
|
||||
self.assertTrue(issubclass(cls, MockBaseClass))
|
||||
instance = cls('prog')
|
||||
self.assertEqual(terminal_width, getattr(instance, '_width'))
|
||||
|
||||
cls = session_parser.CommandParser.help_formatter_factory(terminal_width)
|
||||
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
|
||||
instance = cls('prog')
|
||||
self.assertEqual(terminal_width, getattr(instance, '_width'))
|
||||
|
||||
def test_init(self):
|
||||
self.assertIsInstance(self.parser, ArgumentParser)
|
||||
self.assertEqual(self.session_writer, self.parser._session_writer)
|
||||
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
||||
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
||||
self.assertFalse(getattr(self.parser, 'exit_on_error'))
|
||||
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
||||
|
||||
def test_session_writer(self):
|
||||
self.assertEqual(self.session_writer, self.parser.session_writer)
|
||||
|
||||
def test_terminal_width(self):
|
||||
self.assertEqual(self.terminal_width, self.parser.terminal_width)
|
||||
|
||||
def test__print_message(self):
|
||||
self.session_writer.write = MagicMock()
|
||||
self.assertIsNone(self.parser._print_message(''))
|
||||
self.session_writer.write.assert_not_called()
|
||||
msg = 'foo bar baz'
|
||||
self.assertIsNone(self.parser._print_message(msg))
|
||||
self.session_writer.write.assert_called_once_with(msg.encode())
|
||||
|
||||
@patch.object(session_parser.CommandParser, '_print_message')
|
||||
def test_exit(self, mock__print_message: MagicMock):
|
||||
self.assertIsNone(self.parser.exit(123, ''))
|
||||
mock__print_message.assert_not_called()
|
||||
msg = 'foo bar baz'
|
||||
self.assertIsNone(self.parser.exit(123, msg))
|
||||
mock__print_message.assert_called_once_with(msg)
|
||||
|
||||
@patch.object(session_parser.ArgumentParser, 'print_help')
|
||||
def test_print_help(self, mock_print_help: MagicMock):
|
||||
arg = MagicMock()
|
||||
with self.assertRaises(HelpRequested):
|
||||
self.parser.print_help(arg)
|
||||
mock_print_help.assert_called_once_with(arg)
|
||||
|
||||
def test_add_optional_num_argument(self):
|
||||
metavar = 'FOOBAR'
|
||||
action = self.parser.add_optional_num_argument(metavar=metavar)
|
||||
self.assertIsInstance(action, Action)
|
||||
self.assertEqual('?', action.nargs)
|
||||
self.assertEqual(1, action.default)
|
||||
self.assertEqual(int, action.type)
|
||||
self.assertEqual(metavar, action.metavar)
|
||||
num = 111
|
||||
kwargs = vars(self.parser.parse_args([f'{num}']))
|
||||
self.assertDictEqual({session_parser.NUM: num}, kwargs)
|
||||
|
||||
name = f'--{FOO}'
|
||||
nargs = '+'
|
||||
default = 1
|
||||
_type = float
|
||||
required = True
|
||||
dest = 'foo_bar'
|
||||
action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type,
|
||||
required=required, metavar=metavar, dest=dest)
|
||||
self.assertIsInstance(action, Action)
|
||||
self.assertEqual(nargs, action.nargs)
|
||||
self.assertEqual(default, action.default)
|
||||
self.assertEqual(_type, action.type)
|
||||
self.assertEqual(required, action.required)
|
||||
self.assertEqual(metavar, action.metavar)
|
||||
self.assertEqual(dest, action.dest)
|
||||
kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5']))
|
||||
self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)
|
@ -130,7 +130,7 @@ async def main() -> None:
|
||||
# **only** once there is room in the pool **and** no more than one of these last four tasks is running.
|
||||
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
||||
print("Calling `starmap`...")
|
||||
await pool.starmap(other_work, args_list, num_tasks=2)
|
||||
await pool.starmap(other_work, args_list, group_size=2)
|
||||
print("`starmap` returned")
|
||||
# Now we lock the pool, so that we can safely await all our tasks.
|
||||
pool.lock()
|
||||
|
Loading…
Reference in New Issue
Block a user