generated from daniil-berg/boilerplate-py
Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
4c6a5412ca | |||
44c03cc493 | |||
689a74c678 | |||
3503c0bf44 | |||
3d104c979e | |||
a92e646411 |
17
README.md
17
README.md
@ -2,9 +2,18 @@
|
||||
|
||||
**Dynamically manage pools of asyncio tasks**
|
||||
|
||||
## Contents
|
||||
- [Contents](#contents)
|
||||
- [Summary](#summary)
|
||||
- [Usage](#usage)
|
||||
- [Installation](#installation)
|
||||
- [Dependencies](#dependencies)
|
||||
- [Testing](#testing)
|
||||
- [License](#license)
|
||||
|
||||
## Summary
|
||||
|
||||
A task pool is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
|
||||
A **task pool** is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
|
||||
|
||||
With an interface that is intentionally similar to the [`multiprocessing.Pool`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool) class from the standard library, the `TaskPool` provides you such methods as `apply`, `map`, and `starmap` to execute coroutines concurrently as [`asyncio.Task`](https://docs.python.org/3/library/asyncio-task.html#task-object) objects. There is no limitation imposed on what kind of tasks can be run or in what combination, when new ones can be added, or when they can be cancelled.
|
||||
|
||||
@ -22,7 +31,7 @@ from asyncio_taskpool import SimpleTaskPool
|
||||
...
|
||||
|
||||
|
||||
async def work(foo, bar): ...
|
||||
async def work(_foo, _bar): ...
|
||||
|
||||
|
||||
...
|
||||
@ -55,7 +64,7 @@ Python Version 3.8+, tested on Linux
|
||||
|
||||
## Testing
|
||||
|
||||
Install `asyncio-taskpool[dev]` dependencies or just manually install `coverage` with `pip`.
|
||||
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.
|
||||
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
|
||||
|
||||
## License
|
||||
@ -64,6 +73,6 @@ Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests an
|
||||
|
||||
The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COPYING.LESSER) are included in this repository. If not, see https://www.gnu.org/licenses/.
|
||||
|
||||
## Copyright
|
||||
---
|
||||
|
||||
© 2022 Daniil Fajnberg
|
||||
|
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = asyncio-taskpool
|
||||
version = 0.6.2
|
||||
version = 0.8.0
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Dynamically manage pools of asyncio tasks
|
||||
@ -9,7 +9,7 @@ long_description_content_type = text/markdown
|
||||
keywords = asyncio, concurrency, tasks, coroutines, asynchronous, server
|
||||
url = https://git.fajnberg.de/daniil/asyncio-taskpool
|
||||
project_urls =
|
||||
Bug Tracker = https://git.fajnberg.de/daniil/asyncio-taskpool/issues
|
||||
Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues
|
||||
classifiers =
|
||||
Development Status :: 3 - Alpha
|
||||
Programming Language :: Python :: 3
|
||||
|
@ -41,7 +41,7 @@ class ControlClient(ABC):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def client_info() -> dict:
|
||||
def _client_info() -> dict:
|
||||
"""Returns a dictionary of client information relevant for the handshake with the server."""
|
||||
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
|
||||
|
||||
@ -73,9 +73,10 @@ class ControlClient(ABC):
|
||||
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||
"""
|
||||
self._connected = True
|
||||
writer.write(json.dumps(self.client_info()).encode())
|
||||
writer.write(json.dumps(self._client_info()).encode())
|
||||
await writer.drain()
|
||||
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
||||
|
||||
def _get_command(self, writer: StreamWriter) -> Optional[str]:
|
||||
"""
|
||||
|
@ -20,14 +20,16 @@ This module contains the the definition of the `ControlParser` class used by a c
|
||||
|
||||
|
||||
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
|
||||
from ast import literal_eval
|
||||
from asyncio.streams import StreamWriter
|
||||
from inspect import Parameter, getmembers, isfunction, signature
|
||||
from shutil import get_terminal_size
|
||||
from typing import Callable, Container, Dict, Set, Type, TypeVar
|
||||
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
|
||||
|
||||
from ..constants import CLIENT_INFO, CMD, STREAM_WRITER
|
||||
from ..exceptions import HelpRequested
|
||||
from ..helpers import get_first_doc_line
|
||||
from ..exceptions import HelpRequested, ParserError
|
||||
from ..helpers import get_first_doc_line, resolve_dotted_path
|
||||
from ..types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
||||
|
||||
|
||||
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
|
||||
@ -35,7 +37,6 @@ ParsersDict = Dict[str, 'ControlParser']
|
||||
|
||||
OMIT_PARAMS_DEFAULT = ('self', )
|
||||
|
||||
FORMATTER_CLASS = 'formatter_class'
|
||||
NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description'
|
||||
|
||||
|
||||
@ -79,24 +80,23 @@ class ControlParser(ArgumentParser):
|
||||
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Sets additional internal attributes depending on whether a parent-parser was defined.
|
||||
Subclass of the `ArgumentParser` geared towards asynchronous interaction with an object "from the outside".
|
||||
|
||||
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
|
||||
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
|
||||
Allows directing output to a specified writer rather than stdout/stderr and setting terminal width explicitly.
|
||||
|
||||
Args:
|
||||
stream_writer:
|
||||
The instance of the `asyncio.StreamWriter` to use for message output.
|
||||
terminal_width (optional):
|
||||
The terminal width to assume for all message formatting. Defaults to `shutil.get_terminal_size`.
|
||||
The terminal width to use for all message formatting. Defaults to `shutil.get_terminal_size().columns`.
|
||||
**kwargs(optional):
|
||||
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
|
||||
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
|
||||
is empty.
|
||||
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
|
||||
class is specified, it will always be subclassed in the `help_formatter_factory`.
|
||||
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
|
||||
"""
|
||||
self._stream_writer: StreamWriter = stream_writer
|
||||
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
|
||||
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
|
||||
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
|
||||
kwargs.setdefault('exit_on_error', False)
|
||||
super().__init__(**kwargs)
|
||||
self._flags: Set[str] = set()
|
||||
@ -219,7 +219,7 @@ class ControlParser(ArgumentParser):
|
||||
def error(self, message: str) -> None:
|
||||
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
|
||||
super().error(message=message)
|
||||
raise HelpRequested
|
||||
raise ParserError
|
||||
|
||||
def print_help(self, file=None) -> None:
|
||||
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
|
||||
@ -267,9 +267,8 @@ class ControlParser(ArgumentParser):
|
||||
# This is to be able to later unpack an arbitrary number of positional arguments.
|
||||
kwargs.setdefault('nargs', '*')
|
||||
if not kwargs.get('action') == 'store_true':
|
||||
# The lambda wrapper around the type annotation is to avoid ValueError being raised on suppressed arguments.
|
||||
# See: https://bugs.python.org/issue36078
|
||||
kwargs.setdefault('type', get_arg_type_wrapper(parameter.annotation))
|
||||
# Set the type from the parameter annotation.
|
||||
kwargs.setdefault('type', _get_type_from_annotation(parameter.annotation))
|
||||
return self.add_argument(*name_or_flags, **kwargs)
|
||||
|
||||
def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None:
|
||||
@ -293,7 +292,21 @@ class ControlParser(ArgumentParser):
|
||||
self.add_function_arg(param, help=repr(param.annotation))
|
||||
|
||||
|
||||
def get_arg_type_wrapper(cls: Type) -> Callable:
|
||||
def wrapper(arg):
|
||||
return arg if arg is SUPPRESS else cls(arg)
|
||||
def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
|
||||
"""
|
||||
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
|
||||
|
||||
See: https://bugs.python.org/issue36078
|
||||
"""
|
||||
def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg)
|
||||
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
|
||||
wrapper.__name__ = cls.__name__
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get_type_from_annotation(annotation: Type) -> Callable[[Any], Any]:
|
||||
if any(annotation is t for t in {CoroutineFunc, EndCB, CancelCB}):
|
||||
annotation = resolve_dotted_path
|
||||
if any(annotation is t for t in {ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]}):
|
||||
annotation = literal_eval
|
||||
return _get_arg_type_wrapper(annotation)
|
||||
|
@ -37,7 +37,7 @@ from .session import ControlSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
||||
class ControlServer(ABC):
|
||||
"""
|
||||
Abstract base class for a task pool control server.
|
||||
|
||||
@ -125,6 +125,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
|
||||
async def serve_forever(self) -> Task:
|
||||
"""
|
||||
This method actually starts the server and begins listening to client connections on the specified interface.
|
||||
|
||||
It should never block because the serving will be performed in a separate task.
|
||||
"""
|
||||
log.debug("Starting %s...", self.__class__.__name__)
|
||||
@ -136,7 +137,7 @@ class TCPControlServer(ControlServer):
|
||||
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
|
||||
_client_class = TCPControlClient
|
||||
|
||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
||||
self._host = server_kwargs.pop('host')
|
||||
self._port = server_kwargs.pop('port')
|
||||
super().__init__(pool, **server_kwargs)
|
||||
@ -154,7 +155,7 @@ class UnixControlServer(ControlServer):
|
||||
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
|
||||
_client_class = UnixControlClient
|
||||
|
||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
||||
from asyncio.streams import start_unix_server
|
||||
self._start_unix_server = start_unix_server
|
||||
self._socket_path = Path(server_kwargs.pop('path'))
|
||||
|
@ -27,7 +27,7 @@ from inspect import isfunction, signature
|
||||
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||
|
||||
from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
|
||||
from ..exceptions import CommandError, HelpRequested
|
||||
from ..exceptions import CommandError, HelpRequested, ParserError
|
||||
from ..helpers import return_or_exception
|
||||
from ..pool import TaskPool, SimpleTaskPool
|
||||
from .parser import ControlParser
|
||||
@ -85,7 +85,7 @@ class ControlSession:
|
||||
Must correspond to the arguments expected by the `method`.
|
||||
Correctly unpacks arbitrary-length positional and keyword-arguments.
|
||||
"""
|
||||
log.warning("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__)
|
||||
log.debug("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__)
|
||||
normal_pos, var_pos = [], []
|
||||
for param in signature(method).parameters.values():
|
||||
if param.name == 'self':
|
||||
@ -112,11 +112,11 @@ class ControlSession:
|
||||
executed and the response written to the stream will be its return value (as an encoded string).
|
||||
"""
|
||||
if kwargs:
|
||||
log.warning("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
|
||||
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
|
||||
await return_or_exception(prop.fset, self._pool, **kwargs)
|
||||
self._writer.write(CMD_OK)
|
||||
else:
|
||||
log.warning("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
|
||||
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
|
||||
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
|
||||
|
||||
async def client_handshake(self) -> None:
|
||||
@ -154,9 +154,11 @@ class ControlSession:
|
||||
try:
|
||||
kwargs = vars(self._parser.parse_args(msg.split(' ')))
|
||||
except ArgumentError as e:
|
||||
log.debug("%s got an ArgumentError", self._client_class_name)
|
||||
self._writer.write(str(e).encode())
|
||||
return
|
||||
except HelpRequested:
|
||||
except (HelpRequested, ParserError):
|
||||
log.debug("%s received usage help", self._client_class_name)
|
||||
return
|
||||
command = kwargs.pop(CMD)
|
||||
if isfunction(command):
|
||||
|
@ -67,5 +67,9 @@ class HelpRequested(ServerException):
|
||||
pass
|
||||
|
||||
|
||||
class ParserError(ServerException):
|
||||
pass
|
||||
|
||||
|
||||
class CommandError(ServerException):
|
||||
pass
|
||||
|
@ -15,12 +15,13 @@ You should have received a copy of the GNU Lesser General Public License along w
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Miscellaneous helper functions.
|
||||
Miscellaneous helper functions. None of these should be considered part of the public API.
|
||||
"""
|
||||
|
||||
|
||||
from asyncio.coroutines import iscoroutinefunction
|
||||
from asyncio.queues import Queue
|
||||
from importlib import import_module
|
||||
from inspect import getdoc
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -63,3 +64,22 @@ async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwarg
|
||||
return _function_to_execute(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
|
||||
def resolve_dotted_path(dotted_path: str) -> object:
|
||||
"""
|
||||
Resolves a dotted path to a global object and returns that object.
|
||||
|
||||
Algorithm shamelessly stolen from the `logging.config` module from the standard library.
|
||||
"""
|
||||
names = dotted_path.split('.')
|
||||
module_name = names.pop(0)
|
||||
found = import_module(module_name)
|
||||
for name in names:
|
||||
try:
|
||||
found = getattr(found, name)
|
||||
except AttributeError:
|
||||
module_name += f'.{name}'
|
||||
import_module(module_name)
|
||||
found = getattr(found, name)
|
||||
return found
|
||||
|
@ -59,16 +59,14 @@ class BaseTaskPool:
|
||||
|
||||
@classmethod
|
||||
def _add_pool(cls, pool: 'BaseTaskPool') -> int:
|
||||
"""Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list."""
|
||||
"""Adds a `pool` to the general list of pools and returns it's index."""
|
||||
cls._pools.append(pool)
|
||||
return len(cls._pools) - 1
|
||||
|
||||
def __init__(self, pool_size: int = inf, name: str = None) -> None:
|
||||
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
|
||||
# Initialize a counter for the total number of tasks started through the pool and one for the total number of
|
||||
# tasks cancelled through the pool.
|
||||
# Initialize a counter for the total number of tasks started through the pool.
|
||||
self._num_started: int = 0
|
||||
self._num_cancellations: int = 0
|
||||
|
||||
# Initialize flags; immutably set the name.
|
||||
self._locked: bool = False
|
||||
@ -97,18 +95,18 @@ class BaseTaskPool:
|
||||
|
||||
@property
|
||||
def pool_size(self) -> int:
|
||||
"""Returns the maximum number of concurrently running tasks currently set in the pool."""
|
||||
return self._pool_size
|
||||
"""Maximum number of concurrently running tasks allowed in the pool."""
|
||||
return getattr(self._enough_room, '_value')
|
||||
|
||||
@pool_size.setter
|
||||
def pool_size(self, value: int) -> None:
|
||||
"""
|
||||
Sets the maximum number of concurrently running tasks in the pool.
|
||||
|
||||
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
|
||||
|
||||
Args:
|
||||
value:
|
||||
A non-negative integer.
|
||||
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
|
||||
value: A non-negative integer.
|
||||
|
||||
Raises:
|
||||
`ValueError` if `value` is less than 0.
|
||||
@ -116,11 +114,10 @@ class BaseTaskPool:
|
||||
if value < 0:
|
||||
raise ValueError("Pool size can not be less than 0")
|
||||
self._enough_room._value = value
|
||||
self._pool_size = value
|
||||
|
||||
@property
|
||||
def is_locked(self) -> bool:
|
||||
"""Returns `True` if the pool has been locked (see below)."""
|
||||
"""`True` if the pool has been locked (see below)."""
|
||||
return self._locked
|
||||
|
||||
def lock(self) -> None:
|
||||
@ -138,26 +135,26 @@ class BaseTaskPool:
|
||||
@property
|
||||
def num_running(self) -> int:
|
||||
"""
|
||||
Returns the number of tasks in the pool that are (at that moment) still running.
|
||||
Number of tasks in the pool that are still running.
|
||||
|
||||
At the moment a task's `end_callback` or `cancel_callback` is fired, it is no longer considered running.
|
||||
"""
|
||||
return len(self._tasks_running)
|
||||
|
||||
@property
|
||||
def num_cancellations(self) -> int:
|
||||
def num_cancelled(self) -> int:
|
||||
"""
|
||||
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
|
||||
Number of tasks in the pool that have been cancelled.
|
||||
|
||||
At the moment a task's `cancel_callback` is fired, this counts as a cancellation, and the task is then
|
||||
considered cancelled (instead of running) until its `end_callback` is fired.
|
||||
At the moment a task's `cancel_callback` is fired, it is considered to be cancelled and no longer running,
|
||||
until its `end_callback` is fired, at which point it is considered ended (instead of cancelled).
|
||||
"""
|
||||
return self._num_cancellations
|
||||
return len(self._tasks_cancelled)
|
||||
|
||||
@property
|
||||
def num_ended(self) -> int:
|
||||
"""
|
||||
Returns the number of tasks started through the pool that have stopped running (up until that moment).
|
||||
Number of tasks in the pool that have stopped running.
|
||||
|
||||
At the moment a task's `end_callback` is fired, it is considered ended and no longer running (or cancelled).
|
||||
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
|
||||
@ -165,16 +162,12 @@ class BaseTaskPool:
|
||||
"""
|
||||
return len(self._tasks_ended)
|
||||
|
||||
@property
|
||||
def num_finished(self) -> int:
|
||||
"""Returns the number of tasks in the pool that have finished running (without having been cancelled)."""
|
||||
return len(self._tasks_ended) - self._num_cancellations + len(self._tasks_cancelled)
|
||||
|
||||
@property
|
||||
def is_full(self) -> bool:
|
||||
"""
|
||||
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
|
||||
When the pool is full, any call to start a new task within it will block.
|
||||
`False` if the number of running tasks is less than the `pool_size`.
|
||||
|
||||
When the pool is full, any call to start a new task within it will block, until there is enough room for it.
|
||||
"""
|
||||
return self._enough_room.locked()
|
||||
|
||||
@ -247,7 +240,6 @@ class BaseTaskPool:
|
||||
"""
|
||||
log.debug("Cancelling %s ...", self._task_name(task_id))
|
||||
self._tasks_cancelled[task_id] = self._tasks_running.pop(task_id)
|
||||
self._num_cancellations += 1
|
||||
log.debug("Cancelled %s", self._task_name(task_id))
|
||||
await execute_optional(custom_callback, args=(task_id,))
|
||||
|
||||
@ -276,7 +268,9 @@ class BaseTaskPool:
|
||||
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCB = None,
|
||||
cancel_callback: CancelCB = None) -> Any:
|
||||
"""
|
||||
Universal wrapper around every task run in the pool that returns/raises whatever the wrapped coroutine does.
|
||||
Universal wrapper around every task run in the pool.
|
||||
|
||||
Returns/raises whatever the wrapped coroutine does.
|
||||
|
||||
Responsible for catching cancellation and awaiting the `_task_cancellation` callback, as well as for awaiting
|
||||
the `_task_ending` callback, after the coroutine returns or raises an exception.
|
||||
@ -381,7 +375,9 @@ class BaseTaskPool:
|
||||
|
||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
||||
"""
|
||||
Removes all tasks from the specified group and cancels them, if they are still running.
|
||||
Removes all tasks from the specified group and cancels them.
|
||||
|
||||
Does nothing to tasks, that are no longer running.
|
||||
|
||||
Args:
|
||||
group_name: The name of the group of tasks that shall be cancelled.
|
||||
@ -397,7 +393,9 @@ class BaseTaskPool:
|
||||
|
||||
async def cancel_group(self, group_name: str, msg: str = None) -> None:
|
||||
"""
|
||||
Cancels an entire group of tasks. The task group is subsequently forgotten by the pool.
|
||||
Cancels an entire group of tasks.
|
||||
|
||||
The task group is subsequently forgotten by the pool.
|
||||
|
||||
Args:
|
||||
group_name: The name of the group of tasks that shall be cancelled.
|
||||
@ -430,12 +428,13 @@ class BaseTaskPool:
|
||||
|
||||
async def flush(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks.
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks in the pool.
|
||||
|
||||
This method exists mainly to free up memory of unneeded `Task` objects.
|
||||
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
|
||||
`Task` objects.
|
||||
|
||||
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||
callbacks registered for the tasks block.
|
||||
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
|
||||
registered for the tasks block.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
@ -446,7 +445,9 @@ class BaseTaskPool:
|
||||
|
||||
async def gather_and_close(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool.
|
||||
Calls `asyncio.gather` on **all** tasks in the pool, then closes it.
|
||||
|
||||
After this method returns, no more tasks can be started in the pool.
|
||||
|
||||
The `lock()` method must have been called prior to this.
|
||||
|
||||
@ -473,7 +474,9 @@ class BaseTaskPool:
|
||||
|
||||
class TaskPool(BaseTaskPool):
|
||||
"""
|
||||
General task pool class. Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
|
||||
General purpose task pool class.
|
||||
|
||||
Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
|
||||
|
||||
A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
|
||||
Tasks in the pool can all belong to the same coroutine function,
|
||||
@ -506,12 +509,15 @@ class TaskPool(BaseTaskPool):
|
||||
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
||||
|
||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
||||
"""See base class."""
|
||||
self._cancel_group_meta_tasks(group_name)
|
||||
super()._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
||||
|
||||
async def cancel_group(self, group_name: str, msg: str = None) -> None:
|
||||
"""
|
||||
Cancels an entire group of tasks. The task group is subsequently forgotten by the pool.
|
||||
Cancels an entire group of tasks.
|
||||
|
||||
The task group is subsequently forgotten by the pool.
|
||||
|
||||
If any methods such as `map()` launched meta tasks belonging to that group, these meta tasks are cancelled
|
||||
before the actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will
|
||||
@ -529,7 +535,7 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
async def cancel_all(self, msg: str = None) -> None:
|
||||
"""
|
||||
Cancels all tasks still running within the pool. (This includes all meta tasks.)
|
||||
Cancels all tasks still running within the pool (including meta tasks).
|
||||
|
||||
If any methods such as `map()` launched meta tasks, these meta tasks are cancelled before the actual tasks are
|
||||
cancelled. This means that any tasks "queued" to be started by a meta task will **never even start**. In the
|
||||
@ -569,12 +575,13 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
async def flush(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks.
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks in the pool.
|
||||
|
||||
This method exists mainly to free up memory of unneeded `Task` objects. It also gets rid of unneeded meta tasks.
|
||||
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
|
||||
`Task` objects. It also gets rid of unneeded meta tasks.
|
||||
|
||||
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||
callbacks registered for the tasks block.
|
||||
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
|
||||
registered for the tasks block.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
@ -587,7 +594,9 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
async def gather_and_close(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool.
|
||||
Calls `asyncio.gather` on **all** tasks in the pool, then closes it.
|
||||
|
||||
After this method returns, no more tasks can be started in the pool.
|
||||
|
||||
The `lock()` method must have been called prior to this.
|
||||
|
||||
@ -596,7 +605,6 @@ class TaskPool(BaseTaskPool):
|
||||
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
|
||||
this, make sure to call `cancel_all()` prior to this.
|
||||
|
||||
|
||||
This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
|
||||
the callbacks registered for a task blocks for whatever reason.
|
||||
|
||||
@ -662,9 +670,13 @@ class TaskPool(BaseTaskPool):
|
||||
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
|
||||
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||
"""
|
||||
Creates an arbitrary number of coroutines with the supplied arguments and runs them as new tasks in the pool.
|
||||
Creates tasks with the supplied arguments to be run in the pool.
|
||||
|
||||
Each coroutine looks like `func(*args, **kwargs)`, meaning the `args` and `kwargs` are unpacked and passed
|
||||
into `func` before creating each task, and this is done `num` times.
|
||||
|
||||
All the new tasks are added to the same task group.
|
||||
|
||||
Each coroutine looks like `func(*args, **kwargs)`. All the new tasks are added to the same task group.
|
||||
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
|
||||
|
||||
Args:
|
||||
@ -775,15 +787,23 @@ class TaskPool(BaseTaskPool):
|
||||
if next_arg is self._QUEUE_END_SENTINEL:
|
||||
# The `_queue_producer()` either reached the last argument or was cancelled.
|
||||
return
|
||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
||||
try:
|
||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
||||
except Exception as e:
|
||||
# This means an exception occurred during task **creation**, meaning no task has been created.
|
||||
# It does not imply an error within the task itself.
|
||||
log.exception("%s occurred while trying to create task: %s(%s%s)",
|
||||
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
|
||||
map_semaphore.release()
|
||||
|
||||
async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||
"""
|
||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||
Creates tasks in the pool with arguments from the supplied iterable.
|
||||
|
||||
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
||||
|
||||
All the new tasks are added to the same task group.
|
||||
|
||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
||||
@ -840,10 +860,11 @@ class TaskPool(BaseTaskPool):
|
||||
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None,
|
||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||
"""
|
||||
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||
|
||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
|
||||
|
||||
All the new tasks are added to the same task group.
|
||||
|
||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
||||
@ -939,6 +960,7 @@ class SimpleTaskPool(BaseTaskPool):
|
||||
end_callback: EndCB = None, cancel_callback: CancelCB = None,
|
||||
pool_size: int = inf, name: str = None) -> None:
|
||||
"""
|
||||
Initializes all required attributes.
|
||||
|
||||
Args:
|
||||
func:
|
||||
@ -957,6 +979,9 @@ class SimpleTaskPool(BaseTaskPool):
|
||||
The maximum number of tasks allowed to run concurrently in the pool
|
||||
name (optional):
|
||||
An optional name for the pool.
|
||||
|
||||
Raises:
|
||||
`NotCoroutine` if `func` is not a coroutine function.
|
||||
"""
|
||||
if not iscoroutinefunction(func):
|
||||
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
|
||||
@ -969,7 +994,7 @@ class SimpleTaskPool(BaseTaskPool):
|
||||
|
||||
@property
|
||||
def func_name(self) -> str:
|
||||
"""Returns the name of the coroutine function used in the pool."""
|
||||
"""Name of the coroutine function used in the pool."""
|
||||
return self._func.__name__
|
||||
|
||||
async def _start_one(self) -> int:
|
||||
@ -977,18 +1002,18 @@ class SimpleTaskPool(BaseTaskPool):
|
||||
return await self._start_task(self._func(*self._args, **self._kwargs),
|
||||
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
|
||||
|
||||
async def start(self, num: int = 1) -> List[int]:
|
||||
"""Starts `num` new tasks within the pool and returns their IDs as a list."""
|
||||
async def start(self, num: int) -> List[int]:
|
||||
"""Starts `num` new tasks within the pool and returns their IDs."""
|
||||
ids = await gather(*(self._start_one() for _ in range(num)))
|
||||
assert isinstance(ids, list) # for PyCharm (see above to-do-item)
|
||||
assert isinstance(ids, list) # for PyCharm
|
||||
return ids
|
||||
|
||||
def stop(self, num: int = 1) -> List[int]:
|
||||
def stop(self, num: int) -> List[int]:
|
||||
"""
|
||||
Cancels `num` running tasks within the pool and returns their IDs as a list.
|
||||
Cancels `num` running tasks within the pool and returns their IDs.
|
||||
|
||||
The tasks are canceled in LIFO order, meaning tasks started later will be stopped before those started earlier.
|
||||
If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled.
|
||||
If `num` is greater than or equal to the number of currently running tasks, all tasks are cancelled.
|
||||
"""
|
||||
ids = []
|
||||
for i, task_id in enumerate(reversed(self._tasks_running)):
|
||||
@ -999,5 +1024,5 @@ class SimpleTaskPool(BaseTaskPool):
|
||||
return ids
|
||||
|
||||
def stop_all(self) -> List[int]:
|
||||
"""Cancels all running tasks and returns their IDs as a list."""
|
||||
"""Cancels all running tasks and returns their IDs."""
|
||||
return self.stop(self.num_running)
|
||||
|
@ -53,6 +53,6 @@ class Queue(_Queue):
|
||||
Implements an asynchronous context manager for the queue.
|
||||
|
||||
Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want,
|
||||
but in some situations it makes the codes much cleaner.
|
||||
but in some situations it makes the code much cleaner.
|
||||
"""
|
||||
self.item_processed()
|
||||
|
@ -25,7 +25,7 @@ import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest import IsolatedAsyncioTestCase, skipIf
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
from asyncio_taskpool.control import client
|
||||
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
|
||||
@ -55,7 +55,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def test_client_info(self):
|
||||
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
|
||||
client.ControlClient.client_info())
|
||||
client.ControlClient._client_info())
|
||||
|
||||
async def test_abstract(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
@ -65,16 +65,19 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.kwargs, self.client._conn_kwargs)
|
||||
self.assertFalse(self.client._connected)
|
||||
|
||||
@patch.object(client.ControlClient, 'client_info')
|
||||
async def test__server_handshake(self, mock_client_info: MagicMock):
|
||||
mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
|
||||
@patch.object(client.ControlClient, '_client_info')
|
||||
async def test__server_handshake(self, mock__client_info: MagicMock):
|
||||
mock__client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
|
||||
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||
self.assertTrue(self.client._connected)
|
||||
mock_client_info.assert_called_once_with()
|
||||
mock__client_info.assert_called_once_with()
|
||||
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
||||
self.mock_drain.assert_awaited_once_with()
|
||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||
self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())
|
||||
self.mock_print.assert_has_calls([
|
||||
call("Connected to", self.mock_read.return_value.decode()),
|
||||
call("Type '-h' to get help and usage instructions for all available commands.\n")
|
||||
])
|
||||
|
||||
@patch.object(client, 'input')
|
||||
def test__get_command(self, mock_input: MagicMock):
|
||||
|
@ -20,12 +20,16 @@ Unittests for the `asyncio_taskpool.control.parser` module.
|
||||
|
||||
|
||||
from argparse import ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, SUPPRESS
|
||||
from ast import literal_eval
|
||||
from inspect import signature
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
from typing import Iterable
|
||||
|
||||
from asyncio_taskpool.control import parser
|
||||
from asyncio_taskpool.exceptions import HelpRequested
|
||||
from asyncio_taskpool.exceptions import HelpRequested, ParserError
|
||||
from asyncio_taskpool.helpers import resolve_dotted_path
|
||||
from asyncio_taskpool.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
||||
|
||||
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
@ -41,7 +45,7 @@ class ControlServerTestCase(TestCase):
|
||||
self.kwargs = {
|
||||
'stream_writer': self.stream_writer,
|
||||
'terminal_width': self.terminal_width,
|
||||
parser.FORMATTER_CLASS: FOO
|
||||
'formatter_class': FOO
|
||||
}
|
||||
self.parser = parser.ControlParser(**self.kwargs)
|
||||
|
||||
@ -183,7 +187,7 @@ class ControlServerTestCase(TestCase):
|
||||
|
||||
@patch.object(parser.ArgumentParser, 'error')
|
||||
def test_error(self, mock_supercls_error: MagicMock):
|
||||
with self.assertRaises(HelpRequested):
|
||||
with self.assertRaises(ParserError):
|
||||
self.parser.error(FOO + BAR)
|
||||
mock_supercls_error.assert_called_once_with(message=FOO + BAR)
|
||||
|
||||
@ -194,11 +198,11 @@ class ControlServerTestCase(TestCase):
|
||||
self.parser.print_help(arg)
|
||||
mock_print_help.assert_called_once_with(arg)
|
||||
|
||||
@patch.object(parser, 'get_arg_type_wrapper')
|
||||
@patch.object(parser, '_get_type_from_annotation')
|
||||
@patch.object(parser.ArgumentParser, 'add_argument')
|
||||
def test_add_function_arg(self, mock_add_argument: MagicMock, mock_get_arg_type_wrapper: MagicMock):
|
||||
def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_type_from_annotation: MagicMock):
|
||||
mock_add_argument.return_value = expected_output = 'action'
|
||||
mock_get_arg_type_wrapper.return_value = mock_type = 'fake'
|
||||
mock__get_type_from_annotation.return_value = mock_type = 'fake'
|
||||
|
||||
foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex
|
||||
bar_default, baz_default, boo_default = 1, 0.1, 1j
|
||||
@ -211,42 +215,42 @@ class ControlServerTestCase(TestCase):
|
||||
kwargs = {FOO + BAR: 'xyz'}
|
||||
self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs))
|
||||
mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs)
|
||||
mock_get_arg_type_wrapper.assert_called_once_with(foo_type)
|
||||
mock__get_type_from_annotation.assert_called_once_with(foo_type)
|
||||
|
||||
mock_add_argument.reset_mock()
|
||||
mock_get_arg_type_wrapper.reset_mock()
|
||||
mock__get_type_from_annotation.reset_mock()
|
||||
|
||||
self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs))
|
||||
mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs)
|
||||
mock_get_arg_type_wrapper.assert_called_once_with(args_type)
|
||||
mock__get_type_from_annotation.assert_called_once_with(args_type)
|
||||
|
||||
mock_add_argument.reset_mock()
|
||||
mock_get_arg_type_wrapper.reset_mock()
|
||||
mock__get_type_from_annotation.reset_mock()
|
||||
|
||||
self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs))
|
||||
mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs)
|
||||
mock_get_arg_type_wrapper.assert_called_once_with(bar_type)
|
||||
mock__get_type_from_annotation.assert_called_once_with(bar_type)
|
||||
|
||||
mock_add_argument.reset_mock()
|
||||
mock_get_arg_type_wrapper.reset_mock()
|
||||
mock__get_type_from_annotation.reset_mock()
|
||||
|
||||
self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs))
|
||||
mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs)
|
||||
mock_get_arg_type_wrapper.assert_called_once_with(baz_type)
|
||||
mock__get_type_from_annotation.assert_called_once_with(baz_type)
|
||||
|
||||
mock_add_argument.reset_mock()
|
||||
mock_get_arg_type_wrapper.reset_mock()
|
||||
mock__get_type_from_annotation.reset_mock()
|
||||
|
||||
self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs))
|
||||
mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs)
|
||||
mock_get_arg_type_wrapper.assert_called_once_with(boo_type)
|
||||
mock__get_type_from_annotation.assert_called_once_with(boo_type)
|
||||
|
||||
mock_add_argument.reset_mock()
|
||||
mock_get_arg_type_wrapper.reset_mock()
|
||||
mock__get_type_from_annotation.reset_mock()
|
||||
|
||||
self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs))
|
||||
mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs)
|
||||
mock_get_arg_type_wrapper.assert_not_called()
|
||||
mock__get_type_from_annotation.assert_not_called()
|
||||
|
||||
@patch.object(parser.ControlParser, 'add_function_arg')
|
||||
def test_add_function_args(self, mock_add_function_arg: MagicMock):
|
||||
@ -261,7 +265,25 @@ class ControlServerTestCase(TestCase):
|
||||
|
||||
|
||||
class RestTestCase(TestCase):
|
||||
def test_get_arg_type_wrapper(self):
|
||||
type_wrap = parser.get_arg_type_wrapper(int)
|
||||
def test__get_arg_type_wrapper(self):
|
||||
type_wrap = parser._get_arg_type_wrapper(int)
|
||||
self.assertEqual('int', type_wrap.__name__)
|
||||
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
|
||||
self.assertEqual(13, type_wrap('13'))
|
||||
|
||||
@patch.object(parser, '_get_arg_type_wrapper')
|
||||
def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
|
||||
mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR
|
||||
dotted_path_ann = [CoroutineFunc, EndCB, CancelCB]
|
||||
literal_eval_ann = [ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]]
|
||||
any_other_ann = MagicMock()
|
||||
for a in dotted_path_ann:
|
||||
self.assertEqual(expected_output, parser._get_type_from_annotation(a))
|
||||
mock__get_arg_type_wrapper.assert_has_calls(len(dotted_path_ann) * [call(resolve_dotted_path)])
|
||||
mock__get_arg_type_wrapper.reset_mock()
|
||||
for a in literal_eval_ann:
|
||||
self.assertEqual(expected_output, parser._get_type_from_annotation(a))
|
||||
mock__get_arg_type_wrapper.assert_has_calls(len(literal_eval_ann) * [call(literal_eval)])
|
||||
mock__get_arg_type_wrapper.reset_mock()
|
||||
self.assertEqual(expected_output, parser._get_type_from_annotation(any_other_ann))
|
||||
mock__get_arg_type_wrapper.assert_called_once_with(any_other_ann)
|
||||
|
@ -20,7 +20,7 @@ Unittests for the `asyncio_taskpool.helpers` module.
|
||||
|
||||
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock
|
||||
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
||||
|
||||
from asyncio_taskpool import helpers
|
||||
|
||||
@ -118,3 +118,13 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
||||
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
|
||||
self.assertEqual(test_exception, output)
|
||||
mock_func.assert_called_once_with(*args, **kwargs)
|
||||
|
||||
def test_resolve_dotted_path(self):
|
||||
from logging import WARNING
|
||||
from urllib.request import urlopen
|
||||
self.assertEqual(WARNING, helpers.resolve_dotted_path('logging.WARNING'))
|
||||
self.assertEqual(urlopen, helpers.resolve_dotted_path('urllib.request.urlopen'))
|
||||
with patch.object(helpers, 'import_module', return_value=object) as mock_import_module:
|
||||
with self.assertRaises(AttributeError):
|
||||
helpers.resolve_dotted_path('foo.bar.baz')
|
||||
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
|
||||
|
@ -84,7 +84,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(0, self.task_pool._num_started)
|
||||
self.assertEqual(0, self.task_pool._num_cancellations)
|
||||
|
||||
self.assertFalse(self.task_pool._locked)
|
||||
self.assertFalse(self.task_pool._closed)
|
||||
@ -114,14 +113,14 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
|
||||
def test_pool_size(self):
|
||||
self.pool_size_patcher.stop()
|
||||
self.task_pool._pool_size = self.TEST_POOL_SIZE
|
||||
self.task_pool._enough_room._value = self.TEST_POOL_SIZE
|
||||
self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.task_pool.pool_size = -1
|
||||
|
||||
self.task_pool.pool_size = new_size = 69
|
||||
self.assertEqual(new_size, self.task_pool._pool_size)
|
||||
self.assertEqual(new_size, self.task_pool._enough_room._value)
|
||||
|
||||
def test_is_locked(self):
|
||||
self.task_pool._locked = FOO
|
||||
@ -145,21 +144,14 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ}
|
||||
self.assertEqual(3, self.task_pool.num_running)
|
||||
|
||||
def test_num_cancellations(self):
|
||||
self.task_pool._num_cancellations = 3
|
||||
self.assertEqual(3, self.task_pool.num_cancellations)
|
||||
def test_num_cancelled(self):
|
||||
self.task_pool._tasks_cancelled = {1: FOO, 2: BAR, 3: BAZ}
|
||||
self.assertEqual(3, self.task_pool.num_cancelled)
|
||||
|
||||
def test_num_ended(self):
|
||||
self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ}
|
||||
self.assertEqual(3, self.task_pool.num_ended)
|
||||
|
||||
def test_num_finished(self):
|
||||
self.task_pool._num_cancellations = num_cancellations = 69
|
||||
num_ended = 420
|
||||
self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)}
|
||||
self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ}
|
||||
self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished)
|
||||
|
||||
def test_is_full(self):
|
||||
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
|
||||
|
||||
@ -200,12 +192,10 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
|
||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||
self.task_pool._num_cancellations = cancelled = 3
|
||||
self.task_pool._tasks_running[task_id] = mock_task
|
||||
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._tasks_running)
|
||||
self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id])
|
||||
self.assertEqual(cancelled + 1, self.task_pool._num_cancellations)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
|
||||
@ -603,28 +593,34 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
|
||||
@patch.object(pool, 'star_function')
|
||||
@patch.object(pool.TaskPool, '_start_task')
|
||||
@patch.object(pool, 'Semaphore')
|
||||
@patch.object(pool.TaskPool, '_get_map_end_callback')
|
||||
async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock,
|
||||
@patch.object(pool, 'Semaphore')
|
||||
async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
|
||||
mock__start_task: AsyncMock, mock_star_function: MagicMock):
|
||||
mock__get_map_end_callback.return_value = map_cb = MagicMock()
|
||||
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
|
||||
mock_star_function.return_value = awaitable = 'totally an awaitable'
|
||||
arg1, arg2 = 123456789, 'function argument'
|
||||
mock__get_map_end_callback.return_value = map_cb = MagicMock()
|
||||
awaitable = 'totally an awaitable'
|
||||
mock_star_function.side_effect = [awaitable, awaitable, Exception()]
|
||||
arg1, arg2, bad = 123456789, 'function argument', None
|
||||
mock_q_maxsize = 3
|
||||
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]),
|
||||
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]),
|
||||
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
|
||||
group_name, mock_func, stars = 'whatever', MagicMock(), 3
|
||||
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
|
||||
# We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then
|
||||
# acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3,
|
||||
# at the end of the loop, we expect it be locked.
|
||||
self.assertTrue(semaphore.locked())
|
||||
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
|
||||
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
||||
mock__start_task.assert_has_awaits(2 * [
|
||||
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
|
||||
])
|
||||
mock_star_function.assert_has_calls([
|
||||
call(mock_func, arg1, arg_stars=stars),
|
||||
call(mock_func, arg2, arg_stars=stars)
|
||||
call(mock_func, arg2, arg_stars=stars),
|
||||
call(mock_func, bad, arg_stars=stars)
|
||||
])
|
||||
|
||||
@patch.object(pool, 'create_task')
|
||||
|
@ -1,14 +1,18 @@
|
||||
# Using `asyncio-taskpool`
|
||||
|
||||
## Contents
|
||||
- [Contents](#contents)
|
||||
- [Minimal example for `SimpleTaskPool`](#minimal-example-for-simpletaskpool)
|
||||
- [Advanced example for `TaskPool`](#advanced-example-for-taskpool)
|
||||
- [Control server example](#control-server-example)
|
||||
|
||||
## Minimal example for `SimpleTaskPool`
|
||||
|
||||
With a `SimpleTaskPool` the function to execute as well as the arguments with which to execute it must be defined during its initialization (and they cannot be changed later). The only control you have after initialization is how many of such tasks are being run.
|
||||
|
||||
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
|
||||
|
||||
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
|
||||
|
||||
### Code
|
||||
The following demo script enables full log output first for additional clarity. It is complete and should work as is.
|
||||
|
||||
```python
|
||||
import logging
|
||||
@ -48,7 +52,9 @@ if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Output
|
||||
<details>
|
||||
<summary>Output: (Click to expand)</summary>
|
||||
|
||||
```
|
||||
SimpleTaskPool-0 initialized
|
||||
Started SimpleTaskPool-0_Task-0
|
||||
@ -78,6 +84,7 @@ Ended SimpleTaskPool-0_Task-1
|
||||
> did 4
|
||||
> did 4
|
||||
```
|
||||
</details>
|
||||
|
||||
## Advanced example for `TaskPool`
|
||||
|
||||
@ -85,9 +92,7 @@ This time, we want to start tasks from _different_ coroutine functions **and** w
|
||||
|
||||
As with the simple example, we need "worker" coroutine functions that can do something asynchronously, as well as a main coroutine function that sets up the pool, starts the tasks, and eventually awaits them.
|
||||
|
||||
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
|
||||
|
||||
### Code
|
||||
The following demo script enables full log output first for additional clarity. It is complete and should work as is.
|
||||
|
||||
```python
|
||||
import logging
|
||||
@ -144,10 +149,9 @@ if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Output
|
||||
Additional comments for the output are provided with `<---` next to the output lines.
|
||||
<details>
|
||||
<summary>Output: (Click to expand)</summary>
|
||||
|
||||
(Keep in mind that the logger and `print` asynchronously write to `stdout`.)
|
||||
```
|
||||
TaskPool-0 initialized
|
||||
Started TaskPool-0_Task-0
|
||||
@ -229,4 +233,37 @@ Ended TaskPool-0_Task-5
|
||||
> Done.
|
||||
```
|
||||
|
||||
(Added comments with `<---` next to the output lines.)
|
||||
|
||||
Keep in mind that the logger and `print` asynchronously write to `stdout`, so the order of lines in your output may be slightly different.
|
||||
</details>
|
||||
|
||||
## Control server example
|
||||
|
||||
One of the main features of `asyncio-taskpool` is the ability to control a task pool "from the outside" at runtime.
|
||||
|
||||
The [example_server.py](./example_server.py) script launches a couple of worker tasks within a `SimpleTaskPool` instance and then starts a `TCPControlServer` instance for that task pool. The server is configured to locally bind to port `9999` and is stopped automatically after the "work" is done.
|
||||
|
||||
To run the script:
|
||||
```shell
|
||||
python usage/example_server.py
|
||||
```
|
||||
|
||||
You can then connect to the server via the command line interface:
|
||||
```shell
|
||||
python -m asyncio_taskpool.control tcp localhost 9999
|
||||
```
|
||||
|
||||
The CLI starts a `TCPControlClient` that connects to our example server. Once the connection is established, it gives you an input prompt allowing you to issue commands to the task pool:
|
||||
```
|
||||
Connected to SimpleTaskPool-0
|
||||
Type '-h' to get help and usage instructions for all available commands.
|
||||
|
||||
>
|
||||
```
|
||||
|
||||
It may be useful to run the server script and the client interface in two separate terminal windows side by side. The server script is configured with a verbose logger and will react to any commands issued by the client with detailed log messages in the terminal.
|
||||
|
||||
---
|
||||
|
||||
© 2022 Daniil Fajnberg
|
||||
|
@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Working example of a UnixControlServer in combination with the SimpleTaskPool.
|
||||
Working example of a TCPControlServer in combination with the SimpleTaskPool.
|
||||
Use the main CLI client to interface at the socket.
|
||||
"""
|
||||
|
||||
@ -65,12 +65,12 @@ async def main() -> None:
|
||||
# We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit.
|
||||
for item in range(100):
|
||||
q.put_nowait(item)
|
||||
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
|
||||
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
|
||||
await pool.start(3) # launches three worker tasks
|
||||
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
|
||||
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
||||
await q.join()
|
||||
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
|
||||
# Since we don't need any "work" done anymore, we can get rid of our control server by cancelling the task.
|
||||
control_server_task.cancel()
|
||||
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
||||
# we can now safely cancel their tasks.
|
||||
|
Reference in New Issue
Block a user