generated from daniil-berg/boilerplate-py
Compare commits
8 Commits
v1.0.0-bet
...
82e6ca7b1a
Author | SHA1 | Date | |
---|---|---|---|
82e6ca7b1a
|
|||
153127e028
|
|||
17539e9c27
|
|||
1beb9fc9b0
|
|||
23a4cb028a
|
|||
54e5bfa8a0
|
|||
0e7e92a91b
|
|||
a9011076c4
|
@ -43,7 +43,6 @@ async def main():
|
||||
...
|
||||
pool.stop(3)
|
||||
...
|
||||
pool.lock()
|
||||
await pool.gather_and_close()
|
||||
...
|
||||
```
|
||||
|
@ -96,9 +96,9 @@ When you are dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.T
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
> map mypackage.mymodule.worker ['x','y','z'] -g 3
|
||||
> map mypackage.mymodule.worker ['x','y','z'] -n 3
|
||||
|
||||
The :code:`-g` is a shorthand for :code:`--group-size` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
|
||||
The :code:`-n` is a shorthand for :code:`--num-concurrent` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -46,7 +46,7 @@ Let's take a look at an example. Say you have a coroutine function that takes tw
|
||||
async def queue_worker_function(in_queue: Queue, out_queue: Queue) -> None:
|
||||
while True:
|
||||
item = await in_queue.get()
|
||||
... # Do some work on the item amd arrive at a result.
|
||||
... # Do some work on the item and arrive at a result.
|
||||
await out_queue.put(result)
|
||||
|
||||
How would we go about concurrently executing this function, say 5 times? There are (as always) a number of ways to do this with :code:`asyncio`. If we want to use tasks and be clean about it, we can do it like this:
|
||||
@ -141,9 +141,8 @@ Or we could use a task pool:
|
||||
async def main():
|
||||
...
|
||||
pool = TaskPool()
|
||||
await pool.map(another_worker_function, data_iterator, group_size=5)
|
||||
await pool.map(another_worker_function, data_iterator, num_concurrent=5)
|
||||
...
|
||||
pool.lock()
|
||||
await pool.gather_and_close()
|
||||
|
||||
Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction).
|
||||
@ -229,7 +228,6 @@ The only method of a pool that one should **always** assume to be blocking is :p
|
||||
|
||||
One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations.
|
||||
|
||||
In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, both :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` and :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased).
|
||||
|
||||
:py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants) will **never** block. Since it makes use of "meta-tasks" under the hood, it will always return immediately. However, if the pool was full when it was called, there is **no guarantee** that even a single task has started, when the method returns.
|
||||
In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased).
|
||||
|
||||
:py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants) will **never** block. Since they make use of "meta-tasks" under the hood, they will always return immediately. However, if the pool was full when one of them was called, there is **no guarantee** that even a single task has started, when the method returns.
|
||||
|
@ -20,7 +20,8 @@ Definition of the :class:`ControlParser` used in a
|
||||
"""
|
||||
|
||||
|
||||
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
|
||||
import logging
|
||||
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
|
||||
from ast import literal_eval
|
||||
from asyncio.streams import StreamWriter
|
||||
from inspect import Parameter, getmembers, isfunction, signature
|
||||
@ -36,6 +37,9 @@ from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
||||
__all__ = ['ControlParser']
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
|
||||
ParsersDict = Dict[str, 'ControlParser']
|
||||
|
||||
@ -300,8 +304,21 @@ def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
|
||||
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
|
||||
|
||||
See: https://bugs.python.org/issue36078
|
||||
|
||||
In addition, the type conversion wrapper catches exceptions not handled properly by the parser, logs them, and
|
||||
turns them into `ArgumentTypeError` exceptions the parser can propagate to the client.
|
||||
"""
|
||||
def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg)
|
||||
def wrapper(arg: Any) -> Any:
|
||||
if arg is SUPPRESS:
|
||||
return arg
|
||||
try:
|
||||
return cls(arg)
|
||||
except (ArgumentTypeError, TypeError, ValueError):
|
||||
raise # handled properly by the parser and propagated to the client anyway
|
||||
except Exception as e:
|
||||
text = f"{e.__class__.__name__} occurred in parser trying to convert type: {cls.__name__}({repr(arg)})"
|
||||
log.exception(text)
|
||||
raise ArgumentTypeError(text) # propagate to the client
|
||||
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
|
||||
wrapper.__name__ = cls.__name__
|
||||
return wrapper
|
||||
|
@ -51,10 +51,6 @@ class InvalidGroupName(PoolException):
|
||||
pass
|
||||
|
||||
|
||||
class PoolStillUnlocked(PoolException):
|
||||
pass
|
||||
|
||||
|
||||
class NotCoroutine(PoolException):
|
||||
pass
|
||||
|
||||
|
@ -25,8 +25,6 @@ PACKAGE_NAME = 'asyncio_taskpool'
|
||||
|
||||
DEFAULT_TASK_GROUP = 'default'
|
||||
|
||||
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
|
||||
|
||||
SESSION_MSG_BYTES = 1024 * 100
|
||||
|
||||
STREAM_WRITER = 'stream_writer'
|
||||
|
@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
||||
|
||||
|
||||
from asyncio.coroutines import iscoroutinefunction
|
||||
from asyncio.queues import Queue
|
||||
from importlib import import_module
|
||||
from inspect import getdoc
|
||||
from typing import Any, Optional, Union
|
||||
@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
|
||||
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
|
||||
|
||||
|
||||
async def join_queue(q: Queue) -> None:
|
||||
"""Wrapper function around the join method of an `asyncio.Queue` instance."""
|
||||
await q.join()
|
||||
|
||||
|
||||
def get_first_doc_line(obj: object) -> str:
|
||||
"""Takes an object and returns the first (non-empty) line of its docstring."""
|
||||
return getdoc(obj).strip().split("\n", 1)[0].strip()
|
||||
|
@ -31,18 +31,15 @@ import logging
|
||||
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
||||
from asyncio.exceptions import CancelledError
|
||||
from asyncio.locks import Semaphore
|
||||
from asyncio.queues import QueueEmpty
|
||||
from asyncio.tasks import Task, create_task, gather
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from math import inf
|
||||
from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set, Union
|
||||
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
||||
|
||||
from . import exceptions
|
||||
from .queue_context import Queue
|
||||
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
|
||||
from .internals.constants import DEFAULT_TASK_GROUP
|
||||
from .internals.group_register import TaskGroupRegister
|
||||
from .internals.helpers import execute_optional, star_function, join_queue
|
||||
from .internals.helpers import execute_optional, star_function
|
||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||
|
||||
|
||||
@ -82,8 +79,7 @@ class BaseTaskPool:
|
||||
self._tasks_cancelled: Dict[int, Task] = {}
|
||||
self._tasks_ended: Dict[int, Task] = {}
|
||||
|
||||
# These next three attributes act as synchronisation primitives necessary for managing the pool.
|
||||
self._before_gathering: List[Awaitable] = []
|
||||
# These next two attributes act as synchronisation primitives necessary for managing the pool.
|
||||
self._enough_room: Semaphore = Semaphore()
|
||||
self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
|
||||
|
||||
@ -453,9 +449,7 @@ class BaseTaskPool:
|
||||
"""
|
||||
Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
|
||||
|
||||
After this method returns, no more tasks can be started in the pool.
|
||||
|
||||
:meth:`lock` must have been called prior to this.
|
||||
Once this method is called, no more tasks can be started in the pool.
|
||||
|
||||
This method may block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of the
|
||||
callbacks registered for a task blocks for whatever reason.
|
||||
@ -466,15 +460,12 @@ class BaseTaskPool:
|
||||
Raises:
|
||||
`PoolStillUnlocked`: The pool has not been locked yet.
|
||||
"""
|
||||
if not self._locked:
|
||||
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
|
||||
await gather(*self._before_gathering)
|
||||
self.lock()
|
||||
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._tasks_ended.clear()
|
||||
self._tasks_cancelled.clear()
|
||||
self._tasks_running.clear()
|
||||
self._before_gathering.clear()
|
||||
self._closed = True
|
||||
|
||||
|
||||
@ -494,12 +485,10 @@ class TaskPool(BaseTaskPool):
|
||||
Adding tasks blocks **only if** the pool is full at that moment.
|
||||
"""
|
||||
|
||||
_QUEUE_END_SENTINEL = object()
|
||||
|
||||
def __init__(self, pool_size: int = inf, name: str = None) -> None:
|
||||
super().__init__(pool_size=pool_size, name=name)
|
||||
# In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
|
||||
# meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks.
|
||||
# meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
|
||||
self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
|
||||
self._meta_tasks_cancelled: Set[Task] = set()
|
||||
|
||||
@ -592,24 +581,21 @@ class TaskPool(BaseTaskPool):
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
"""
|
||||
await super().flush(return_exceptions=return_exceptions)
|
||||
with suppress(CancelledError):
|
||||
await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._meta_tasks_cancelled.clear()
|
||||
await super().flush(return_exceptions=return_exceptions)
|
||||
|
||||
async def gather_and_close(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
|
||||
|
||||
After this method returns, no more tasks can be started in the pool.
|
||||
|
||||
The `lock()` method must have been called prior to this.
|
||||
Once this method is called, no more tasks can be started in the pool.
|
||||
|
||||
Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
|
||||
tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed,
|
||||
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
|
||||
this, make sure to call :meth:`cancel_all` prior to this.
|
||||
tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is
|
||||
fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first.
|
||||
|
||||
This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
|
||||
the callbacks registered for a task blocks for whatever reason.
|
||||
@ -620,29 +606,32 @@ class TaskPool(BaseTaskPool):
|
||||
Raises:
|
||||
`PoolStillUnlocked`: The pool has not been locked yet.
|
||||
"""
|
||||
# TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`).
|
||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
||||
not_cancelled_meta_tasks = set()
|
||||
while self._group_meta_tasks_running:
|
||||
_, meta_tasks = self._group_meta_tasks_running.popitem()
|
||||
not_cancelled_meta_tasks.update(meta_tasks)
|
||||
self.lock()
|
||||
not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
|
||||
with suppress(CancelledError):
|
||||
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
|
||||
self._meta_tasks_cancelled.clear()
|
||||
self._group_meta_tasks_running.clear()
|
||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
|
||||
def _generate_group_name(self, prefix: str, coroutine_function: CoroutineFunc) -> str:
|
||||
"""
|
||||
Creates a task group identifier that includes the current datetime.
|
||||
Creates a unique task group identifier.
|
||||
|
||||
Args:
|
||||
prefix: The start of the name; will be followed by an underscore.
|
||||
coroutine_function: The function representing the task group.
|
||||
|
||||
Returns:
|
||||
The constructed 'prefix_function_datetime' string to name a task group.
|
||||
The constructed 'prefix_function_index' string to name a task group.
|
||||
"""
|
||||
return f'{prefix}_{coroutine_function.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}'
|
||||
base_name = f'{prefix}_{coroutine_function.__name__}'
|
||||
i = 0
|
||||
while True:
|
||||
name = f'{base_name}_{i}'
|
||||
if name not in self._task_groups.keys():
|
||||
return name
|
||||
i += 1
|
||||
|
||||
async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
|
||||
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||
@ -684,7 +673,9 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
All the new tasks are added to the same task group.
|
||||
|
||||
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
|
||||
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`.
|
||||
|
||||
Args:
|
||||
func:
|
||||
@ -696,7 +687,8 @@ class TaskPool(BaseTaskPool):
|
||||
num (optional):
|
||||
The number of tasks to spawn with the specified parameters.
|
||||
group_name (optional):
|
||||
Name of the task group to add the new tasks to.
|
||||
Name of the task group to add the new tasks to. By default, a unique name is constructed using the
|
||||
name of the provided `func` and an incrementing index as 'apply_func_index'.
|
||||
end_callback (optional):
|
||||
A callback to execute after a task has ended.
|
||||
It is run with the task's ID as its only positional argument.
|
||||
@ -705,7 +697,7 @@ class TaskPool(BaseTaskPool):
|
||||
It is run with the task's ID as its only positional argument.
|
||||
|
||||
Returns:
|
||||
The name of the task group that the newly spawned tasks have been added to.
|
||||
The name of the newly created task group (see the `group_name` parameter).
|
||||
|
||||
Raises:
|
||||
`PoolIsClosed`: The pool is closed.
|
||||
@ -717,63 +709,38 @@ class TaskPool(BaseTaskPool):
|
||||
group_name = self._generate_group_name('apply', func)
|
||||
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
|
||||
async with group_reg:
|
||||
task = create_task(self._apply_num(group_name, func, args, kwargs, num, end_callback, cancel_callback))
|
||||
await task
|
||||
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||
meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)))
|
||||
return group_name
|
||||
|
||||
@classmethod
|
||||
async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None:
|
||||
"""
|
||||
Keeps the arguments queue from :meth:`_map` full as long as the iterator has elements.
|
||||
|
||||
Intended to be run as a meta task of a specific group.
|
||||
|
||||
Args:
|
||||
arg_queue: The queue of function arguments to consume for starting a new task.
|
||||
arg_iter: The iterator of function arguments to put into the queue.
|
||||
group_name: Name of the task group associated with this producer.
|
||||
"""
|
||||
try:
|
||||
for arg in arg_iter:
|
||||
await arg_queue.put(arg) # This blocks as long as the queue is full.
|
||||
except CancelledError:
|
||||
# This means that no more tasks are supposed to be created from this `_map()` call;
|
||||
# thus, we can immediately drain the entire queue and forget about the rest of the arguments.
|
||||
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
|
||||
while True:
|
||||
try:
|
||||
arg_queue.get_nowait()
|
||||
arg_queue.item_processed()
|
||||
except QueueEmpty:
|
||||
return
|
||||
finally:
|
||||
await arg_queue.put(cls._QUEUE_END_SENTINEL)
|
||||
|
||||
@staticmethod
|
||||
def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB:
|
||||
"""Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`."""
|
||||
"""Returns a wrapped `end_callback` for each :meth:`_arg_consumer` task that releases the `map_semaphore`."""
|
||||
async def release_callback(task_id: int) -> None:
|
||||
map_semaphore.release()
|
||||
await execute_optional(actual_end_callback, args=(task_id,))
|
||||
return release_callback
|
||||
|
||||
async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0,
|
||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||
async def _arg_consumer(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT,
|
||||
arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||
"""
|
||||
Consumes arguments from the queue from :meth:`_map` and keeps a limited number of tasks working on them.
|
||||
Consumes arguments from :meth:`_map` and keeps a limited number of tasks working on them.
|
||||
|
||||
The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before
|
||||
a new task can be started, and which must be released when one of these tasks ends.
|
||||
`num_concurrent` acts as the limiting value of an internal semaphore, which must be acquired before a new task
|
||||
can be started, and which must be released when one of these tasks ends.
|
||||
|
||||
Intended to be run as a meta task of a specific group.
|
||||
|
||||
Args:
|
||||
arg_queue:
|
||||
The queue of function arguments to consume for starting a new task.
|
||||
group_name:
|
||||
Name of the associated task group; passed into :meth:`_start_task`.
|
||||
num_concurrent:
|
||||
The maximum number new tasks spawned by this method to run concurrently.
|
||||
func:
|
||||
The coroutine function to use for spawning the new tasks within the task pool.
|
||||
arg_iter:
|
||||
The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
|
||||
arg_stars (optional):
|
||||
Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2.
|
||||
end_callback (optional):
|
||||
@ -783,28 +750,29 @@ class TaskPool(BaseTaskPool):
|
||||
The callback that was specified to execute after cancellation of the task (and the next one).
|
||||
It is run with the task's ID as its only positional argument.
|
||||
"""
|
||||
map_semaphore = Semaphore(arg_queue.maxsize) # value determined by `group_size` in :meth:`_map`
|
||||
map_semaphore = Semaphore(num_concurrent)
|
||||
release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
|
||||
while True:
|
||||
# The following line blocks **only if** the number of running tasks spawned by this method has reached the
|
||||
# specified maximum as determined in :meth:`_map`.
|
||||
for next_arg in arg_iter:
|
||||
# When the number of running tasks spawned by this method reaches the specified maximum,
|
||||
# this next line will block, until one of them ends and releases the semaphore.
|
||||
await map_semaphore.acquire()
|
||||
# We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
|
||||
async with arg_queue as next_arg:
|
||||
if next_arg is self._QUEUE_END_SENTINEL:
|
||||
# The :meth:`_queue_producer` either reached the last argument or was cancelled.
|
||||
return
|
||||
try:
|
||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
||||
except Exception as e:
|
||||
# This means an exception occurred during task **creation**, meaning no task has been created.
|
||||
# It does not imply an error within the task itself.
|
||||
log.exception("%s occurred while trying to create task: %s(%s%s)",
|
||||
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
|
||||
map_semaphore.release()
|
||||
try:
|
||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
||||
except CancelledError:
|
||||
# This means that no more tasks are supposed to be created from this `arg_iter`;
|
||||
# thus, we can forget about the rest of the arguments.
|
||||
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
|
||||
map_semaphore.release()
|
||||
return
|
||||
except Exception as e:
|
||||
# This means an exception occurred during task **creation**, meaning no task has been created.
|
||||
# It does not imply an error within the task itself.
|
||||
log.exception("%s occurred while trying to create task: %s(%s%s)",
|
||||
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
|
||||
map_semaphore.release()
|
||||
|
||||
async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
||||
async def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||
"""
|
||||
Creates tasks in the pool with arguments from the supplied iterable.
|
||||
@ -813,23 +781,21 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
All the new tasks are added to the same task group.
|
||||
|
||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
||||
any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
|
||||
`num_concurrent` determines the (maximum) number of tasks spawned this way that shall be running concurrently at
|
||||
any given moment in time. As soon as one task from this method call ends, it triggers the start of a new task
|
||||
(assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
|
||||
of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
|
||||
concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
|
||||
of the pool never imposes a limit, this ensures that the number of tasks spawned and running concurrently is
|
||||
always equal to `num_concurrent` (except for when `arg_iter` is exhausted of course).
|
||||
|
||||
This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
|
||||
Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
|
||||
aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does
|
||||
not mean that any task was started or that any number of tasks will start soon, as this is solely determined by
|
||||
the :attr:`BaseTaskPool.pool_size` and the `group_size`.
|
||||
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num_concurrent`.
|
||||
|
||||
Args:
|
||||
group_name:
|
||||
Name of the task group to add the new tasks to. It must be a name that doesn't exist yet.
|
||||
group_size:
|
||||
The maximum number new tasks spawned by this method to run concurrently.
|
||||
num_concurrent:
|
||||
The number new tasks spawned by this method to run concurrently.
|
||||
func:
|
||||
The coroutine function to use for spawning the new tasks within the task pool.
|
||||
arg_iter:
|
||||
@ -844,32 +810,21 @@ class TaskPool(BaseTaskPool):
|
||||
It is run with the task's ID as its only positional argument.
|
||||
|
||||
Raises:
|
||||
`ValueError`: `group_size` is less than 1.
|
||||
`ValueError`: `num_concurrent` is less than 1.
|
||||
`asyncio_taskpool.exceptions.InvalidGroupName`: A group named `group_name` exists in the pool.
|
||||
"""
|
||||
self._check_start(function=func)
|
||||
if group_size < 1:
|
||||
raise ValueError(f"Group size must be a positive integer.")
|
||||
if num_concurrent < 1:
|
||||
raise ValueError("`num_concurrent` must be a positive integer.")
|
||||
if group_name in self._task_groups.keys():
|
||||
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
|
||||
self._task_groups[group_name] = group_reg = TaskGroupRegister()
|
||||
async with group_reg:
|
||||
# Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
|
||||
# `_queue_producer()`; that way an argument
|
||||
arg_queue = Queue(maxsize=group_size)
|
||||
# TODO: This is the wrong thing to await before gathering!
|
||||
# Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
|
||||
# "finishes" the entire queue before the producer manages to put more items in it, thus returning
|
||||
# the `join` call before the arguments iterator was fully consumed.
|
||||
# Probably the queue producer task should be awaited before gathering instead.
|
||||
self._before_gathering.append(join_queue(arg_queue))
|
||||
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||
# Start the producer and consumer meta tasks.
|
||||
meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
|
||||
meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars,
|
||||
end_callback, cancel_callback)))
|
||||
meta_tasks.add(create_task(self._arg_consumer(group_name, num_concurrent, func, arg_iter, arg_stars,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)))
|
||||
|
||||
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None,
|
||||
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
|
||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||
"""
|
||||
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||
@ -879,27 +834,27 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
All the new tasks are added to the same task group.
|
||||
|
||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
||||
any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
|
||||
`num_concurrent` determines the (maximum) number of tasks spawned this way that shall be running concurrently at
|
||||
any given moment in time. As soon as one task from this method call ends, it triggers the start of a new task
|
||||
(assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
|
||||
of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
|
||||
concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
|
||||
of the pool never imposes a limit, this ensures that the number of tasks spawned and running concurrently is
|
||||
always equal to `num_concurrent` (except for when `arg_iter` is exhausted of course).
|
||||
|
||||
This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
|
||||
Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
|
||||
aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does
|
||||
not mean that any task was started or that any number of tasks will start soon, as this is solely determined by
|
||||
the :attr:`BaseTaskPool.pool_size` and the `group_size`.
|
||||
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num_concurrent`.
|
||||
|
||||
Args:
|
||||
func:
|
||||
The coroutine function to use for spawning the new tasks within the task pool.
|
||||
arg_iter:
|
||||
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
|
||||
group_size (optional):
|
||||
The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
|
||||
num_concurrent (optional):
|
||||
The number new tasks spawned by this method to run concurrently. Defaults to 1.
|
||||
group_name (optional):
|
||||
Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet.
|
||||
By default, a unique name is constructed using the name of the provided `func` and an incrementing
|
||||
index as 'apply_func_index'.
|
||||
end_callback (optional):
|
||||
A callback to execute after a task has ended.
|
||||
It is run with the task's ID as its only positional argument.
|
||||
@ -908,22 +863,22 @@ class TaskPool(BaseTaskPool):
|
||||
It is run with the task's ID as its only positional argument.
|
||||
|
||||
Returns:
|
||||
The name of the task group that the newly spawned tasks will be added to.
|
||||
The name of the newly created task group (see the `group_name` parameter).
|
||||
|
||||
Raises:
|
||||
`PoolIsClosed`: The pool is closed.
|
||||
`NotCoroutine`: `func` is not a coroutine function.
|
||||
`PoolIsLocked`: The pool is currently locked.
|
||||
`ValueError`: `group_size` is less than 1.
|
||||
`ValueError`: `num_concurrent` is less than 1.
|
||||
`InvalidGroupName`: A group named `group_name` exists in the pool.
|
||||
"""
|
||||
if group_name is None:
|
||||
group_name = self._generate_group_name('map', func)
|
||||
await self._map(group_name, group_size, func, arg_iter, 0,
|
||||
await self._map(group_name, num_concurrent, func, arg_iter, 0,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
return group_name
|
||||
|
||||
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
|
||||
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1,
|
||||
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||
"""
|
||||
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
||||
@ -932,11 +887,11 @@ class TaskPool(BaseTaskPool):
|
||||
"""
|
||||
if group_name is None:
|
||||
group_name = self._generate_group_name('starmap', func)
|
||||
await self._map(group_name, group_size, func, args_iter, 1,
|
||||
await self._map(group_name, num_concurrent, func, args_iter, 1,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
return group_name
|
||||
|
||||
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
|
||||
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
|
||||
group_name: str = None, end_callback: EndCB = None,
|
||||
cancel_callback: CancelCB = None) -> str:
|
||||
"""
|
||||
@ -946,7 +901,7 @@ class TaskPool(BaseTaskPool):
|
||||
"""
|
||||
if group_name is None:
|
||||
group_name = self._generate_group_name('doublestarmap', func)
|
||||
await self._map(group_name, group_size, func, kwargs_iter, 2,
|
||||
await self._map(group_name, num_concurrent, func, kwargs_iter, 2,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
return group_name
|
||||
|
||||
|
@ -35,7 +35,7 @@ from asyncio_taskpool.internals.types import ArgsT, CancelCB, CoroutineFunc, End
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class ControlServerTestCase(TestCase):
|
||||
class ControlParserTestCase(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
|
||||
@ -265,12 +265,36 @@ class ControlServerTestCase(TestCase):
|
||||
|
||||
|
||||
class RestTestCase(TestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = parser.log.level
|
||||
parser.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
parser.log.setLevel(cls.log_lvl)
|
||||
|
||||
def test__get_arg_type_wrapper(self):
|
||||
type_wrap = parser._get_arg_type_wrapper(int)
|
||||
self.assertEqual('int', type_wrap.__name__)
|
||||
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
|
||||
self.assertEqual(13, type_wrap('13'))
|
||||
|
||||
name = 'abcdef'
|
||||
mock_type = MagicMock(side_effect=[parser.ArgumentTypeError, TypeError, ValueError, Exception], __name__=name)
|
||||
type_wrap = parser._get_arg_type_wrapper(mock_type)
|
||||
self.assertEqual(name, type_wrap.__name__)
|
||||
with self.assertRaises(parser.ArgumentTypeError):
|
||||
type_wrap(FOO)
|
||||
with self.assertRaises(TypeError):
|
||||
type_wrap(FOO)
|
||||
with self.assertRaises(ValueError):
|
||||
type_wrap(FOO)
|
||||
with self.assertRaises(parser.ArgumentTypeError):
|
||||
type_wrap(FOO)
|
||||
|
||||
@patch.object(parser, '_get_arg_type_wrapper')
|
||||
def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
|
||||
mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR
|
||||
|
@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
helpers.star_function(f, a, 123456789)
|
||||
|
||||
async def test_join_queue(self):
|
||||
mock_join = AsyncMock()
|
||||
mock_queue = MagicMock(join=mock_join)
|
||||
self.assertIsNone(await helpers.join_queue(mock_queue))
|
||||
mock_join.assert_awaited_once_with()
|
||||
|
||||
def test_get_first_doc_line(self):
|
||||
expected_output = 'foo bar baz'
|
||||
mock_obj = MagicMock(__doc__=f"""{expected_output}
|
||||
|
@ -20,14 +20,11 @@ Unittests for the `asyncio_taskpool.pool` module.
|
||||
|
||||
from asyncio.exceptions import CancelledError
|
||||
from asyncio.locks import Semaphore
|
||||
from asyncio.queues import QueueEmpty
|
||||
from datetime import datetime
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||
from typing import Type
|
||||
|
||||
from asyncio_taskpool import pool, exceptions
|
||||
from asyncio_taskpool.internals.constants import DATETIME_FORMAT
|
||||
|
||||
|
||||
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
|
||||
@ -93,7 +90,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||
|
||||
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
|
||||
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||
|
||||
@ -365,32 +361,23 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||
|
||||
async def test_gather_and_close(self):
|
||||
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
|
||||
@patch.object(pool.BaseTaskPool, 'lock')
|
||||
async def test_gather_and_close(self, mock_lock: MagicMock):
|
||||
mock_running_func = AsyncMock()
|
||||
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
||||
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
|
||||
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
|
||||
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._tasks_running = running = {789: mock_running_func()}
|
||||
|
||||
with self.assertRaises(exceptions.PoolStillUnlocked):
|
||||
await self.task_pool.gather_and_close()
|
||||
self.assertDictEqual(ended, self.task_pool._tasks_ended)
|
||||
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
|
||||
self.assertDictEqual(running, self.task_pool._tasks_running)
|
||||
self.assertListEqual(before_gather, self.task_pool._before_gathering)
|
||||
self.assertFalse(self.task_pool._closed)
|
||||
self.task_pool._tasks_ended = {123: mock_ended_func()}
|
||||
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._tasks_running = {789: mock_running_func()}
|
||||
|
||||
self.task_pool._locked = True
|
||||
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
||||
mock_before_gather.assert_awaited_once_with()
|
||||
mock_lock.assert_called_once_with()
|
||||
mock_ended_func.assert_awaited_once_with()
|
||||
mock_cancelled_func.assert_awaited_once_with()
|
||||
mock_running_func.assert_awaited_once_with()
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
|
||||
self.assertTrue(self.task_pool._closed)
|
||||
|
||||
|
||||
@ -471,13 +458,15 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
mock_cancelled_meta_task.assert_awaited_once_with()
|
||||
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||
|
||||
@patch.object(pool.BaseTaskPool, 'lock')
|
||||
@patch.object(pool.BaseTaskPool, 'gather_and_close')
|
||||
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock):
|
||||
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock, mock_lock: MagicMock):
|
||||
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
|
||||
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
|
||||
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
||||
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
||||
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
||||
mock_lock.assert_called_once_with()
|
||||
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
|
||||
mock_meta_task1.assert_awaited_once_with()
|
||||
mock_meta_task2.assert_awaited_once_with()
|
||||
@ -485,13 +474,15 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
||||
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||
|
||||
@patch.object(pool, 'datetime')
|
||||
def test__generate_group_name(self, mock_datetime: MagicMock):
|
||||
def test__generate_group_name(self):
|
||||
prefix, func = 'x y z', AsyncMock(__name__=BAR)
|
||||
dt = datetime(1776, 7, 4, 0, 0, 1)
|
||||
mock_datetime.now = MagicMock(return_value=dt)
|
||||
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}'
|
||||
output = pool.TaskPool._generate_group_name(prefix, func)
|
||||
self.task_pool._task_groups = {
|
||||
f'{prefix}_{BAR}_0': MagicMock(),
|
||||
f'{prefix}_{BAR}_1': MagicMock(),
|
||||
f'{prefix}_{BAR}_100': MagicMock(),
|
||||
}
|
||||
expected_output = f'{prefix}_{BAR}_2'
|
||||
output = self.task_pool._generate_group_name(prefix, func)
|
||||
self.assertEqual(expected_output, output)
|
||||
|
||||
@patch.object(pool.TaskPool, '_start_task')
|
||||
@ -526,8 +517,7 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
mock__generate_group_name.return_value = generated_name = 'name 123'
|
||||
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||
mock__apply_num.return_value = mock_apply_coroutine = object()
|
||||
mock_task_future = AsyncMock()
|
||||
mock_create_task.return_value = mock_task_future()
|
||||
mock_create_task.return_value = fake_task = object()
|
||||
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
|
||||
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
@ -538,10 +528,11 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
mock__check_start.assert_called_once_with(function=mock_func)
|
||||
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
|
||||
mock_group_reg.__aenter__.assert_awaited_once_with()
|
||||
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
|
||||
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock_create_task.assert_called_once_with(mock_apply_coroutine)
|
||||
mock_group_reg.__aexit__.assert_awaited_once()
|
||||
mock_task_future.assert_awaited_once_with()
|
||||
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
|
||||
|
||||
output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
|
||||
check_assertions(group_name, output)
|
||||
@ -553,35 +544,11 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
mock__apply_num.reset_mock()
|
||||
mock_create_task.reset_mock()
|
||||
mock_group_reg.__aexit__.reset_mock()
|
||||
mock_task_future = AsyncMock()
|
||||
mock_create_task.return_value = mock_task_future()
|
||||
|
||||
output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
|
||||
check_assertions(generated_name, output)
|
||||
mock__generate_group_name.assert_called_once_with('apply', mock_func)
|
||||
|
||||
@patch.object(pool, 'Queue')
|
||||
async def test__queue_producer(self, mock_queue_cls: MagicMock):
|
||||
mock_put = AsyncMock()
|
||||
mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
|
||||
item1, item2, item3 = FOO, 420, 69
|
||||
arg_iter = iter([item1, item2, item3])
|
||||
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
|
||||
mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
|
||||
with self.assertRaises(StopIteration):
|
||||
next(arg_iter)
|
||||
|
||||
mock_put.reset_mock()
|
||||
|
||||
mock_put.side_effect = [CancelledError, None]
|
||||
arg_iter = iter([item1, item2, item3])
|
||||
mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
|
||||
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
|
||||
mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
|
||||
mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
|
||||
mock_queue.item_processed.assert_has_calls([call(), call()])
|
||||
self.assertListEqual([item2, item3], list(arg_iter))
|
||||
|
||||
@patch.object(pool, 'execute_optional')
|
||||
async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
|
||||
semaphore, mock_end_cb = Semaphore(1), MagicMock()
|
||||
@ -597,84 +564,85 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
@patch.object(pool, 'Semaphore')
|
||||
async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
|
||||
mock__start_task: AsyncMock, mock_star_function: MagicMock):
|
||||
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
|
||||
n = 2
|
||||
mock_semaphore_cls.return_value = semaphore = Semaphore(n)
|
||||
mock__get_map_end_callback.return_value = map_cb = MagicMock()
|
||||
awaitable = 'totally an awaitable'
|
||||
mock_star_function.side_effect = [awaitable, awaitable, Exception()]
|
||||
mock_star_function.side_effect = [awaitable, Exception(), awaitable]
|
||||
arg1, arg2, bad = 123456789, 'function argument', None
|
||||
mock_q_maxsize = 3
|
||||
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]),
|
||||
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
|
||||
args = [arg1, bad, arg2]
|
||||
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
|
||||
# We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then
|
||||
# acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3,
|
||||
# at the end of the loop, we expect it be locked.
|
||||
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
|
||||
# We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then
|
||||
# acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
|
||||
self.assertTrue(semaphore.locked())
|
||||
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
|
||||
mock_semaphore_cls.assert_called_once_with(n)
|
||||
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
||||
mock__start_task.assert_has_awaits(2 * [
|
||||
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
|
||||
])
|
||||
mock_star_function.assert_has_calls([
|
||||
call(mock_func, arg1, arg_stars=stars),
|
||||
call(mock_func, arg2, arg_stars=stars),
|
||||
call(mock_func, bad, arg_stars=stars)
|
||||
call(mock_func, bad, arg_stars=stars),
|
||||
call(mock_func, arg2, arg_stars=stars)
|
||||
])
|
||||
|
||||
mock_semaphore_cls.reset_mock()
|
||||
mock__get_map_end_callback.reset_mock()
|
||||
mock__start_task.reset_mock()
|
||||
mock_star_function.reset_mock()
|
||||
|
||||
# With a CancelledError thrown while starting a task:
|
||||
mock_semaphore_cls.return_value = semaphore = Semaphore(1)
|
||||
mock_star_function.side_effect = CancelledError()
|
||||
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
|
||||
self.assertFalse(semaphore.locked())
|
||||
mock_semaphore_cls.assert_called_once_with(n)
|
||||
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
||||
mock__start_task.assert_not_called()
|
||||
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
|
||||
|
||||
@patch.object(pool, 'create_task')
|
||||
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
|
||||
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
|
||||
@patch.object(pool, 'join_queue', new_callable=MagicMock)
|
||||
@patch.object(pool, 'Queue')
|
||||
@patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
|
||||
@patch.object(pool, 'TaskGroupRegister')
|
||||
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
|
||||
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
|
||||
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock,
|
||||
mock_create_task: MagicMock):
|
||||
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||
mock_queue_cls.return_value = mock_q = MagicMock()
|
||||
mock_join_queue.return_value = fake_join = object()
|
||||
mock__queue_producer.return_value = fake_producer = object()
|
||||
mock__queue_consumer.return_value = fake_consumer = object()
|
||||
fake_task1, fake_task2 = object(), object()
|
||||
mock_create_task.side_effect = [fake_task1, fake_task2]
|
||||
mock__arg_consumer.return_value = fake_consumer = object()
|
||||
mock_create_task.return_value = fake_task = object()
|
||||
|
||||
group_name, group_size = 'onetwothree', 0
|
||||
group_name, n = 'onetwothree', 0
|
||||
func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
|
||||
await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
|
||||
mock__check_start.assert_called_once_with(function=func)
|
||||
|
||||
mock__check_start.reset_mock()
|
||||
|
||||
group_size = 1234
|
||||
n = 1234
|
||||
self.task_pool._task_groups = {group_name: MagicMock()}
|
||||
|
||||
with self.assertRaises(exceptions.InvalidGroupName):
|
||||
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
|
||||
await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
|
||||
mock__check_start.assert_called_once_with(function=func)
|
||||
|
||||
mock__check_start.reset_mock()
|
||||
|
||||
self.task_pool._task_groups.clear()
|
||||
self.task_pool._before_gathering = []
|
||||
|
||||
self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
|
||||
self.assertIsNone(await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb))
|
||||
mock__check_start.assert_called_once_with(function=func)
|
||||
mock_reg_cls.assert_called_once_with()
|
||||
self.task_pool._task_groups[group_name] = mock_group_reg
|
||||
mock_group_reg.__aenter__.assert_awaited_once_with()
|
||||
mock_queue_cls.assert_called_once_with(maxsize=group_size)
|
||||
mock_join_queue.assert_called_once_with(mock_q)
|
||||
self.assertListEqual([fake_join], self.task_pool._before_gathering)
|
||||
mock__queue_producer.assert_called_once()
|
||||
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
|
||||
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
|
||||
self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
|
||||
mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock_create_task.assert_called_once_with(fake_consumer)
|
||||
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
|
||||
mock_group_reg.__aexit__.assert_awaited_once()
|
||||
|
||||
@patch.object(pool.TaskPool, '_map')
|
||||
@ -682,18 +650,18 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
||||
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||
mock_func = MagicMock()
|
||||
arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
|
||||
arg_iter, num_concurrent, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
|
||||
output = await self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb)
|
||||
self.assertEqual(group_name, output)
|
||||
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
|
||||
mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, arg_iter, 0,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock__generate_group_name.assert_not_called()
|
||||
|
||||
mock__map.reset_mock()
|
||||
output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
|
||||
output = await self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb)
|
||||
self.assertEqual(generated_name, output)
|
||||
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
|
||||
mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock__generate_group_name.assert_called_once_with('map', mock_func)
|
||||
|
||||
@ -702,18 +670,18 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
||||
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||
mock_func = MagicMock()
|
||||
args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
|
||||
args_iter, num_concurrent, group_name = ([FOO], [BAR]), 2, FOO + BAR
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
|
||||
output = await self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb)
|
||||
self.assertEqual(group_name, output)
|
||||
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
|
||||
mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, args_iter, 1,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock__generate_group_name.assert_not_called()
|
||||
|
||||
mock__map.reset_mock()
|
||||
output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
|
||||
output = await self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb)
|
||||
self.assertEqual(generated_name, output)
|
||||
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
|
||||
mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, args_iter, 1,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock__generate_group_name.assert_called_once_with('starmap', mock_func)
|
||||
|
||||
@ -722,18 +690,18 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
||||
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||
mock_func = MagicMock()
|
||||
kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
|
||||
kw_iter, num_concurrent, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
|
||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
|
||||
output = await self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb)
|
||||
self.assertEqual(group_name, output)
|
||||
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
|
||||
mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, kw_iter, 2,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock__generate_group_name.assert_not_called()
|
||||
|
||||
mock__map.reset_mock()
|
||||
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
|
||||
output = await self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb)
|
||||
self.assertEqual(generated_name, output)
|
||||
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
|
||||
mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2,
|
||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
|
||||
|
||||
|
@ -41,10 +41,9 @@ async def main() -> None:
|
||||
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
|
||||
await pool.start(3) # launches work tasks 0, 1, and 2
|
||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||
await pool.start() # launches work task 3
|
||||
await pool.start(1) # launches work task 3
|
||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
|
||||
pool.lock() # required for the last line
|
||||
await pool.gather_and_close() # awaits all tasks, then flushes the pool
|
||||
|
||||
|
||||
@ -135,11 +134,9 @@ async def main() -> None:
|
||||
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
|
||||
# only once there is room in the pool and no more than one other task of these new ones is running.
|
||||
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
||||
await pool.starmap(other_work, args_list, group_size=2)
|
||||
await pool.starmap(other_work, args_list, num_concurrent=2)
|
||||
print("> Called `starmap`")
|
||||
# Now we lock the pool, so that we can safely await all our tasks.
|
||||
pool.lock()
|
||||
# Finally, we block, until all tasks have ended.
|
||||
# We block, until all tasks have ended.
|
||||
print("> Calling `gather_and_close`...")
|
||||
await pool.gather_and_close()
|
||||
print("> Done.")
|
||||
@ -199,7 +196,7 @@ Started TaskPool-0_Task-3
|
||||
> other_work with 15
|
||||
Ended TaskPool-0_Task-0
|
||||
Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
|
||||
Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start
|
||||
Started TaskPool-0_Task-4 <--- since `num_concurrent` is set to 2, Task-5 will not start
|
||||
> work with 190
|
||||
> work with 190
|
||||
> other_work with 16
|
||||
|
@ -75,7 +75,6 @@ async def main() -> None:
|
||||
control_server_task.cancel()
|
||||
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
||||
# we can now safely cancel their tasks.
|
||||
pool.lock()
|
||||
pool.stop_all()
|
||||
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
|
||||
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
||||
|
Reference in New Issue
Block a user