Compare commits

...

6 Commits

12 changed files with 193 additions and 264 deletions

View File

@@ -43,7 +43,6 @@ async def main():
     ...
     pool.stop(3)
     ...
-    pool.lock()
     await pool.gather_and_close()
     ...
 ```

View File

@@ -96,9 +96,9 @@ When you are dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.T
 .. code-block:: none
-   > map mypackage.mymodule.worker ['x','y','z'] -g 3
+   > map mypackage.mymodule.worker ['x','y','z'] -n 3
-The :code:`-g` is a shorthand for :code:`--group-size` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
+The :code:`-n` is a shorthand for :code:`--num-concurrent` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
 .. note::
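The renamed flag simply forwards to the renamed keyword argument. For orientation, the control command above corresponds roughly to the following direct call (a minimal sketch; the `worker` coroutine body and the top-level `TaskPool` import are assumptions, not part of this diff):

```python
import asyncio
from asyncio_taskpool import TaskPool  # top-level re-export assumed

async def worker(item: str) -> None:
    """Stand-in for mypackage.mymodule.worker from the docs example."""
    await asyncio.sleep(0.1)

async def main() -> None:
    pool = TaskPool()
    # Equivalent of the control command: map mypackage.mymodule.worker ['x','y','z'] -n 3
    await pool.map(worker, ['x', 'y', 'z'], num_concurrent=3)
    await pool.gather_and_close()

asyncio.run(main())
```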

View File

@@ -46,7 +46,7 @@ Let's take a look at an example. Say you have a coroutine function that takes tw
 async def queue_worker_function(in_queue: Queue, out_queue: Queue) -> None:
     while True:
         item = await in_queue.get()
-        ... # Do some work on the item amd arrive at a result.
+        ... # Do some work on the item and arrive at a result.
         await out_queue.put(result)
 How would we go about concurrently executing this function, say 5 times? There are (as always) a number of ways to do this with :code:`asyncio`. If we want to use tasks and be clean about it, we can do it like this:
@@ -141,9 +141,8 @@ Or we could use a task pool:
 async def main():
     ...
     pool = TaskPool()
-    await pool.map(another_worker_function, data_iterator, group_size=5)
+    await pool.map(another_worker_function, data_iterator, num_concurrent=5)
     ...
-    pool.lock()
     await pool.gather_and_close()
 Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction).
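Assembled into a runnable script, the updated example looks roughly like this (a sketch; `another_worker_function` and the data are placeholders, and `TaskPool` is assumed importable from the package root):

```python
import asyncio
from asyncio_taskpool import TaskPool  # import path assumed

async def another_worker_function(item: int) -> None:
    await asyncio.sleep(0.1)  # placeholder for real work on `item`

async def main() -> None:
    pool = TaskPool()
    # At any moment, at most 5 tasks work on the iterable's elements.
    await pool.map(another_worker_function, range(20), num_concurrent=5)
    # No explicit pool.lock() needed anymore; gather_and_close() locks the pool itself.
    await pool.gather_and_close()

asyncio.run(main())
```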
@@ -229,7 +228,6 @@ The only method of a pool that one should **always** assume to be blocking is :p
 One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations.
-In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, both :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` and :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased).
+In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased).
-:py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants) will **never** block. Since it makes use of "meta-tasks" under the hood, it will always return immediately. However, if the pool was full when it was called, there is **no guarantee** that even a single task has started, when the method returns.
+:py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants) will **never** block. Since they make use of "meta-tasks" under the hood, they will always return immediately. However, if the pool was full when one of them was called, there is **no guarantee** that even a single task has started, when the method returns.
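A small experiment makes the non-blocking behavior visible (a sketch under assumed imports; the timing printout is illustrative, not part of the library):

```python
import asyncio
import time
from asyncio_taskpool import TaskPool  # import path assumed

async def slow(x: int) -> None:
    await asyncio.sleep(1)

async def main() -> None:
    pool = TaskPool(pool_size=1)  # deliberately tiny pool
    t0 = time.perf_counter()
    await pool.map(slow, range(5), num_concurrent=3)
    # Returns almost instantly: a meta task spawns workers as room frees up.
    print(f"map() returned after {time.perf_counter() - t0:.3f}s")
    await pool.gather_and_close()  # this is the call that actually waits

asyncio.run(main())
```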

View File

@@ -20,7 +20,8 @@ Definition of the :class:`ControlParser` used in a
 """
-from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
+import logging
+from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
 from ast import literal_eval
 from asyncio.streams import StreamWriter
 from inspect import Parameter, getmembers, isfunction, signature
@@ -36,6 +37,9 @@ from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
 __all__ = ['ControlParser']
+log = logging.getLogger(__name__)
 FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
 ParsersDict = Dict[str, 'ControlParser']
@@ -300,8 +304,21 @@ def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
     Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
     See: https://bugs.python.org/issue36078
+    In addition, the type conversion wrapper catches exceptions not handled properly by the parser, logs them, and
+    turns them into `ArgumentTypeError` exceptions the parser can propagate to the client.
     """
-    def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg)
+    def wrapper(arg: Any) -> Any:
+        if arg is SUPPRESS:
+            return arg
+        try:
+            return cls(arg)
+        except (ArgumentTypeError, TypeError, ValueError):
+            raise  # handled properly by the parser and propagated to the client anyway
+        except Exception as e:
+            text = f"{e.__class__.__name__} occurred in parser trying to convert type: {cls.__name__}({repr(arg)})"
+            log.exception(text)
+            raise ArgumentTypeError(text)  # propagate to the client
     # Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
     wrapper.__name__ = cls.__name__
     return wrapper
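In isolation, the wrapper pattern introduced here looks like this (a self-contained sketch mirroring the diff, not the library module verbatim):

```python
import logging
from argparse import ArgumentParser, ArgumentTypeError, SUPPRESS
from typing import Any, Callable, Type

log = logging.getLogger(__name__)

def get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
    """Wrap `cls` so unexpected conversion errors surface as ArgumentTypeError."""
    def wrapper(arg: Any) -> Any:
        if arg is SUPPRESS:
            return arg  # pass the sentinel through untouched
        try:
            return cls(arg)
        except (ArgumentTypeError, TypeError, ValueError):
            raise  # argparse already handles these properly
        except Exception as e:
            text = f"{e.__class__.__name__} converting {cls.__name__}({arg!r})"
            log.exception(text)
            raise ArgumentTypeError(text)  # the only kind argparse reports cleanly
    # Keep the original class name so help/error messages stay readable.
    wrapper.__name__ = cls.__name__
    return wrapper

parser = ArgumentParser()
parser.add_argument("--num", type=get_arg_type_wrapper(int))
print(parser.parse_args(["--num", "42"]).num)  # 42
```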

View File

@@ -51,10 +51,6 @@ class InvalidGroupName(PoolException):
     pass
-class PoolStillUnlocked(PoolException):
-    pass
 class NotCoroutine(PoolException):
     pass

View File

@@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p
 from asyncio.coroutines import iscoroutinefunction
-from asyncio.queues import Queue
 from importlib import import_module
 from inspect import getdoc
 from typing import Any, Optional, Union
@@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
     raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
-async def join_queue(q: Queue) -> None:
-    """Wrapper function around the join method of an `asyncio.Queue` instance."""
-    await q.join()
 def get_first_doc_line(obj: object) -> str:
     """Takes an object and returns the first (non-empty) line of its docstring."""
     return getdoc(obj).strip().split("\n", 1)[0].strip()

View File

@@ -31,18 +31,16 @@ import logging
 from asyncio.coroutines import iscoroutine, iscoroutinefunction
 from asyncio.exceptions import CancelledError
 from asyncio.locks import Semaphore
-from asyncio.queues import QueueEmpty
 from asyncio.tasks import Task, create_task, gather
 from contextlib import suppress
 from datetime import datetime
 from math import inf
-from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set, Union
+from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
 from . import exceptions
-from .queue_context import Queue
 from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
 from .internals.group_register import TaskGroupRegister
-from .internals.helpers import execute_optional, star_function, join_queue
+from .internals.helpers import execute_optional, star_function
 from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@@ -82,8 +80,7 @@ class BaseTaskPool:
         self._tasks_cancelled: Dict[int, Task] = {}
         self._tasks_ended: Dict[int, Task] = {}
-        # These next three attributes act as synchronisation primitives necessary for managing the pool.
-        self._before_gathering: List[Awaitable] = []
+        # These next two attributes act as synchronisation primitives necessary for managing the pool.
         self._enough_room: Semaphore = Semaphore()
         self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
@@ -453,9 +450,7 @@
         """
         Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
-        After this method returns, no more tasks can be started in the pool.
-        :meth:`lock` must have been called prior to this.
+        Once this method is called, no more tasks can be started in the pool.
         This method may block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of the
         callbacks registered for a task blocks for whatever reason.
@@ -466,15 +461,12 @@
         Raises:
             `PoolStillUnlocked`: The pool has not been locked yet.
         """
-        if not self._locked:
-            raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
-        await gather(*self._before_gathering)
+        self.lock()
         await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
                      return_exceptions=return_exceptions)
         self._tasks_ended.clear()
         self._tasks_cancelled.clear()
         self._tasks_running.clear()
-        self._before_gathering.clear()
         self._closed = True
@@ -494,12 +486,10 @@ class TaskPool(BaseTaskPool):
     Adding tasks blocks **only if** the pool is full at that moment.
     """
-    _QUEUE_END_SENTINEL = object()
     def __init__(self, pool_size: int = inf, name: str = None) -> None:
         super().__init__(pool_size=pool_size, name=name)
         # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
-        # meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks.
+        # meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
         self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
         self._meta_tasks_cancelled: Set[Task] = set()
@@ -592,24 +582,21 @@
         Args:
             return_exceptions (optional): Passed directly into `gather`.
         """
-        await super().flush(return_exceptions=return_exceptions)
         with suppress(CancelledError):
             await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
                          return_exceptions=return_exceptions)
         self._meta_tasks_cancelled.clear()
+        await super().flush(return_exceptions=return_exceptions)
     async def gather_and_close(self, return_exceptions: bool = False):
         """
         Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
-        After this method returns, no more tasks can be started in the pool.
-        The `lock()` method must have been called prior to this.
+        Once this method is called, no more tasks can be started in the pool.
         Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
-        tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed,
-        which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
-        this, make sure to call :meth:`cancel_all` prior to this.
+        tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is
+        fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first.
         This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
         the callbacks registered for a task blocks for whatever reason.
@@ -620,15 +607,12 @@
         Raises:
             `PoolStillUnlocked`: The pool has not been locked yet.
         """
-        # TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`).
-        await super().gather_and_close(return_exceptions=return_exceptions)
-        not_cancelled_meta_tasks = set()
-        while self._group_meta_tasks_running:
-            _, meta_tasks = self._group_meta_tasks_running.popitem()
-            not_cancelled_meta_tasks.update(meta_tasks)
+        not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
         with suppress(CancelledError):
             await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
         self._meta_tasks_cancelled.clear()
+        self._group_meta_tasks_running.clear()
+        await super().gather_and_close(return_exceptions=return_exceptions)
     @staticmethod
     def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
@@ -684,7 +668,9 @@
         All the new tasks are added to the same task group.
-        This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
+        Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
+        because this method returns immediately, this does not mean that any task was started or that any number of
+        tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`.
         Args:
             func:
@@ -717,38 +703,11 @@
         group_name = self._generate_group_name('apply', func)
         group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
         async with group_reg:
-            task = create_task(self._apply_num(group_name, func, args, kwargs, num, end_callback, cancel_callback))
-            await task
+            meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
+            meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
+                                                       end_callback=end_callback, cancel_callback=cancel_callback)))
         return group_name
-    @classmethod
-    async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None:
-        """
-        Keeps the arguments queue from :meth:`_map` full as long as the iterator has elements.
-        Intended to be run as a meta task of a specific group.
-        Args:
-            arg_queue: The queue of function arguments to consume for starting a new task.
-            arg_iter: The iterator of function arguments to put into the queue.
-            group_name: Name of the task group associated with this producer.
-        """
-        try:
-            for arg in arg_iter:
-                await arg_queue.put(arg)  # This blocks as long as the queue is full.
-        except CancelledError:
-            # This means that no more tasks are supposed to be created from this `_map()` call;
-            # thus, we can immediately drain the entire queue and forget about the rest of the arguments.
-            log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
-            while True:
-                try:
-                    arg_queue.get_nowait()
-                    arg_queue.item_processed()
-                except QueueEmpty:
-                    return
-        finally:
-            await arg_queue.put(cls._QUEUE_END_SENTINEL)
     @staticmethod
     def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB:
         """Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`."""
@@ -757,23 +716,25 @@
             await execute_optional(actual_end_callback, args=(task_id,))
         return release_callback
-    async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0,
-                              end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
+    async def _arg_consumer(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT,
+                            arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
         """
-        Consumes arguments from the queue from :meth:`_map` and keeps a limited number of tasks working on them.
+        Consumes arguments from :meth:`_map` and keeps a limited number of tasks working on them.
-        The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before
-        a new task can be started, and which must be released when one of these tasks ends.
+        `num_concurrent` acts as the limiting value of an internal semaphore, which must be acquired before a new task
+        can be started, and which must be released when one of these tasks ends.
         Intended to be run as a meta task of a specific group.
         Args:
-            arg_queue:
-                The queue of function arguments to consume for starting a new task.
             group_name:
                 Name of the associated task group; passed into :meth:`_start_task`.
+            num_concurrent:
+                The maximum number new tasks spawned by this method to run concurrently.
             func:
                 The coroutine function to use for spawning the new tasks within the task pool.
+            arg_iter:
+                The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
             arg_stars (optional):
                 Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2.
             end_callback (optional):
@@ -783,20 +744,21 @@
                 The callback that was specified to execute after cancellation of the task (and the next one).
                 It is run with the task's ID as its only positional argument.
         """
-        map_semaphore = Semaphore(arg_queue.maxsize)  # value determined by `group_size` in :meth:`_map`
+        map_semaphore = Semaphore(num_concurrent)
         release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
-        while True:
-            # The following line blocks **only if** the number of running tasks spawned by this method has reached the
-            # specified maximum as determined in :meth:`_map`.
+        for next_arg in arg_iter:
+            # When the number of running tasks spawned by this method reaches the specified maximum,
+            # this next line will block, until one of them ends and releases the semaphore.
             await map_semaphore.acquire()
-            # We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
-            async with arg_queue as next_arg:
-                if next_arg is self._QUEUE_END_SENTINEL:
-                    # The :meth:`_queue_producer` either reached the last argument or was cancelled.
-                    return
             try:
                 await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
                                        ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
+            except CancelledError:
+                # This means that no more tasks are supposed to be created from this `arg_iter`;
+                # thus, we can forget about the rest of the arguments.
+                log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
+                map_semaphore.release()
+                return
             except Exception as e:
                 # This means an exception occurred during task **creation**, meaning no task has been created.
                 # It does not imply an error within the task itself.
@@ -804,7 +766,7 @@
                           str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
                 map_semaphore.release()
-    async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
+    async def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
                    end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
         """
         Creates tasks in the pool with arguments from the supplied iterable.
@@ -813,23 +775,21 @@
         All the new tasks are added to the same task group.
-        The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
-        any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
+        `num_concurrent` determines the (maximum) number of tasks spawned this way that shall be running concurrently at
+        any given moment in time. As soon as one task from this method call ends, it triggers the start of a new task
         (assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
-        of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
-        concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
+        of the pool never imposes a limit, this ensures that the number of tasks spawned and running concurrently is
+        always equal to `num_concurrent` (except for when `arg_iter` is exhausted of course).
-        This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
-        Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
-        aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does
-        not mean that any task was started or that any number of tasks will start soon, as this is solely determined by
-        the :attr:`BaseTaskPool.pool_size` and the `group_size`.
+        Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
+        because this method returns immediately, this does not mean that any task was started or that any number of
+        tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num_concurrent`.
         Args:
             group_name:
                 Name of the task group to add the new tasks to. It must be a name that doesn't exist yet.
-            group_size:
-                The maximum number new tasks spawned by this method to run concurrently.
+            num_concurrent:
+                The number new tasks spawned by this method to run concurrently.
             func:
                 The coroutine function to use for spawning the new tasks within the task pool.
             arg_iter:
@@ -844,32 +804,21 @@
                 It is run with the task's ID as its only positional argument.
         Raises:
-            `ValueError`: `group_size` is less than 1.
+            `ValueError`: `num_concurrent` is less than 1.
             `asyncio_taskpool.exceptions.InvalidGroupName`: A group named `group_name` exists in the pool.
         """
         self._check_start(function=func)
-        if group_size < 1:
-            raise ValueError(f"Group size must be a positive integer.")
+        if num_concurrent < 1:
+            raise ValueError("`num_concurrent` must be a positive integer.")
         if group_name in self._task_groups.keys():
             raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
         self._task_groups[group_name] = group_reg = TaskGroupRegister()
         async with group_reg:
-            # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
-            # `_queue_producer()`; that way an argument
-            arg_queue = Queue(maxsize=group_size)
-            # TODO: This is the wrong thing to await before gathering!
-            #       Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
-            #       "finishes" the entire queue before the producer manages to put more items in it, thus returning
-            #       the `join` call before the arguments iterator was fully consumed.
-            #       Probably the queue producer task should be awaited before gathering instead.
-            self._before_gathering.append(join_queue(arg_queue))
             meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
-            # Start the producer and consumer meta tasks.
-            meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
-            meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars,
-                                                            end_callback, cancel_callback)))
+            meta_tasks.add(create_task(self._arg_consumer(group_name, num_concurrent, func, arg_iter, arg_stars,
+                                                          end_callback=end_callback, cancel_callback=cancel_callback)))
-    async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None,
+    async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
                   end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
         """
         A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
@@ -879,25 +828,23 @@
         All the new tasks are added to the same task group.
-        The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
-        any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
+        `num_concurrent` determines the (maximum) number of tasks spawned this way that shall be running concurrently at
+        any given moment in time. As soon as one task from this method call ends, it triggers the start of a new task
         (assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
-        of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
-        concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
+        of the pool never imposes a limit, this ensures that the number of tasks spawned and running concurrently is
+        always equal to `num_concurrent` (except for when `arg_iter` is exhausted of course).
-        This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
-        Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
-        aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does
-        not mean that any task was started or that any number of tasks will start soon, as this is solely determined by
-        the :attr:`BaseTaskPool.pool_size` and the `group_size`.
+        Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
+        because this method returns immediately, this does not mean that any task was started or that any number of
+        tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num_concurrent`.
         Args:
             func:
                 The coroutine function to use for spawning the new tasks within the task pool.
             arg_iter:
                 The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
-            group_size (optional):
-                The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
+            num_concurrent (optional):
                The number new tasks spawned by this method to run concurrently. Defaults to 1.
             group_name (optional):
                 Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet.
             end_callback (optional):
@@ -914,16 +861,16 @@
             `PoolIsClosed`: The pool is closed.
             `NotCoroutine`: `func` is not a coroutine function.
             `PoolIsLocked`: The pool is currently locked.
-            `ValueError`: `group_size` is less than 1.
+            `ValueError`: `num_concurrent` is less than 1.
             `InvalidGroupName`: A group named `group_name` exists in the pool.
         """
         if group_name is None:
             group_name = self._generate_group_name('map', func)
-        await self._map(group_name, group_size, func, arg_iter, 0,
+        await self._map(group_name, num_concurrent, func, arg_iter, 0,
                         end_callback=end_callback, cancel_callback=cancel_callback)
         return group_name
-    async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
+    async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1,
                       group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
         """
         Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
@@ -932,11 +879,11 @@
         """
         if group_name is None:
             group_name = self._generate_group_name('starmap', func)
-        await self._map(group_name, group_size, func, args_iter, 1,
+        await self._map(group_name, num_concurrent, func, args_iter, 1,
                         end_callback=end_callback, cancel_callback=cancel_callback)
         return group_name
-    async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
+    async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
                             group_name: str = None, end_callback: EndCB = None,
                             cancel_callback: CancelCB = None) -> str:
         """
@@ -946,7 +893,7 @@
         """
         if group_name is None:
             group_name = self._generate_group_name('doublestarmap', func)
-        await self._map(group_name, group_size, func, kwargs_iter, 2,
+        await self._map(group_name, num_concurrent, func, kwargs_iter, 2,
                         end_callback=end_callback, cancel_callback=cancel_callback)
         return group_name
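The core of the new `_arg_consumer` is a generic semaphore-throttling pattern. Here it is stripped of all pool bookkeeping (a standalone sketch illustrating the technique, not the library code):

```python
import asyncio
from typing import Awaitable, Callable, Iterable

async def throttled_spawn(func: Callable[[int], Awaitable[None]],
                          args: Iterable[int], num_concurrent: int) -> None:
    """Run func(arg) for each arg, with at most num_concurrent tasks alive at once."""
    sem = asyncio.Semaphore(num_concurrent)

    def release(_task: asyncio.Task) -> None:
        sem.release()  # a finished task makes room for the next one

    tasks = []
    for arg in args:
        await sem.acquire()  # blocks while num_concurrent tasks are running
        task = asyncio.create_task(func(arg))
        task.add_done_callback(release)
        tasks.append(task)
    await asyncio.gather(*tasks)

async def demo(x: int) -> None:
    await asyncio.sleep(0.1)

asyncio.run(throttled_spawn(demo, range(10), num_concurrent=3))
```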

View File

@@ -35,7 +35,7 @@ from asyncio_taskpool.internals.types import ArgsT, CancelCB, CoroutineFunc, End
 FOO, BAR = 'foo', 'bar'
-class ControlServerTestCase(TestCase):
+class ControlParserTestCase(TestCase):
     def setUp(self) -> None:
         self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
@@ -265,12 +265,36 @@ class ControlServerTestCase(TestCase):
 class RestTestCase(TestCase):
+    log_lvl: int
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.log_lvl = parser.log.level
+        parser.log.setLevel(999)
+    @classmethod
+    def tearDownClass(cls) -> None:
+        parser.log.setLevel(cls.log_lvl)
     def test__get_arg_type_wrapper(self):
         type_wrap = parser._get_arg_type_wrapper(int)
         self.assertEqual('int', type_wrap.__name__)
         self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
         self.assertEqual(13, type_wrap('13'))
+        name = 'abcdef'
+        mock_type = MagicMock(side_effect=[parser.ArgumentTypeError, TypeError, ValueError, Exception], __name__=name)
+        type_wrap = parser._get_arg_type_wrapper(mock_type)
+        self.assertEqual(name, type_wrap.__name__)
+        with self.assertRaises(parser.ArgumentTypeError):
+            type_wrap(FOO)
+        with self.assertRaises(TypeError):
+            type_wrap(FOO)
+        with self.assertRaises(ValueError):
+            type_wrap(FOO)
+        with self.assertRaises(parser.ArgumentTypeError):
+            type_wrap(FOO)
     @patch.object(parser, '_get_arg_type_wrapper')
     def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
         mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR

View File

@@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
         with self.assertRaises(ValueError):
             helpers.star_function(f, a, 123456789)
-    async def test_join_queue(self):
-        mock_join = AsyncMock()
-        mock_queue = MagicMock(join=mock_join)
-        self.assertIsNone(await helpers.join_queue(mock_queue))
-        mock_join.assert_awaited_once_with()
     def test_get_first_doc_line(self):
         expected_output = 'foo bar baz'
         mock_obj = MagicMock(__doc__=f"""{expected_output}

View File

@@ -20,7 +20,6 @@ Unittests for the `asyncio_taskpool.pool` module.
 from asyncio.exceptions import CancelledError
 from asyncio.locks import Semaphore
-from asyncio.queues import QueueEmpty
 from datetime import datetime
 from unittest import IsolatedAsyncioTestCase
 from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
@@ -93,7 +92,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
-        self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
         self.assertIsInstance(self.task_pool._enough_room, Semaphore)
         self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
@@ -366,31 +364,20 @@
         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
     async def test_gather_and_close(self):
-        mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
+        mock_running_func = AsyncMock()
         mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
-        self.task_pool._before_gathering = before_gather = [mock_before_gather()]
-        self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
-        self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
-        self.task_pool._tasks_running = running = {789: mock_running_func()}
-        with self.assertRaises(exceptions.PoolStillUnlocked):
-            await self.task_pool.gather_and_close()
-        self.assertDictEqual(ended, self.task_pool._tasks_ended)
-        self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
-        self.assertDictEqual(running, self.task_pool._tasks_running)
-        self.assertListEqual(before_gather, self.task_pool._before_gathering)
-        self.assertFalse(self.task_pool._closed)
+        self.task_pool._tasks_ended = {123: mock_ended_func()}
+        self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
+        self.task_pool._tasks_running = {789: mock_running_func()}
         self.task_pool._locked = True
         self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
-        mock_before_gather.assert_awaited_once_with()
         mock_ended_func.assert_awaited_once_with()
         mock_cancelled_func.assert_awaited_once_with()
         mock_running_func.assert_awaited_once_with()
         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
-        self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
         self.assertTrue(self.task_pool._closed)
@@ -526,8 +513,7 @@ class TaskPoolTestCase(CommonTestCase):
         mock__generate_group_name.return_value = generated_name = 'name 123'
         mock_group_reg = set_up_mock_group_register(mock_reg_cls)
         mock__apply_num.return_value = mock_apply_coroutine = object()
-        mock_task_future = AsyncMock()
-        mock_create_task.return_value = mock_task_future()
+        mock_create_task.return_value = fake_task = object()
         mock_func, num, group_name = MagicMock(), 3, FOO + BAR
         args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
         end_cb, cancel_cb = MagicMock(), MagicMock()
@@ -538,10 +524,11 @@
             mock__check_start.assert_called_once_with(function=mock_func)
             self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
             mock_group_reg.__aenter__.assert_awaited_once_with()
-            mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
+            mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
+                                                    end_callback=end_cb, cancel_callback=cancel_cb)
             mock_create_task.assert_called_once_with(mock_apply_coroutine)
             mock_group_reg.__aexit__.assert_awaited_once()
-            mock_task_future.assert_awaited_once_with()
+            self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
         output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
         check_assertions(group_name, output)
@@ -553,35 +540,11 @@
         mock__apply_num.reset_mock()
         mock_create_task.reset_mock()
         mock_group_reg.__aexit__.reset_mock()
-        mock_task_future = AsyncMock()
-        mock_create_task.return_value = mock_task_future()
         output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
         check_assertions(generated_name, output)
         mock__generate_group_name.assert_called_once_with('apply', mock_func)
-    @patch.object(pool, 'Queue')
-    async def test__queue_producer(self, mock_queue_cls: MagicMock):
-        mock_put = AsyncMock()
-        mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
-        item1, item2, item3 = FOO, 420, 69
-        arg_iter = iter([item1, item2, item3])
-        self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
-        mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
-        with self.assertRaises(StopIteration):
-            next(arg_iter)
-        mock_put.reset_mock()
-        mock_put.side_effect = [CancelledError, None]
-        arg_iter = iter([item1, item2, item3])
-        mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
-        self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
-        mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
-        mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
-        mock_queue.item_processed.assert_has_calls([call(), call()])
-        self.assertListEqual([item2, item3], list(arg_iter))
     @patch.object(pool, 'execute_optional')
     async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
         semaphore, mock_end_cb = Semaphore(1), MagicMock()
@@ -597,84 +560,85 @@
     @patch.object(pool, 'Semaphore')
     async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
                                    mock__start_task: AsyncMock, mock_star_function: MagicMock):
-        mock_semaphore_cls.return_value = semaphore = Semaphore(3)
+        n = 2
+        mock_semaphore_cls.return_value = semaphore = Semaphore(n)
         mock__get_map_end_callback.return_value = map_cb = MagicMock()
         awaitable = 'totally an awaitable'
-        mock_star_function.side_effect = [awaitable, awaitable, Exception()]
+        mock_star_function.side_effect = [awaitable, Exception(), awaitable]
         arg1, arg2, bad = 123456789, 'function argument', None
-        mock_q_maxsize = 3
-        mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]),
-                           __aexit__=AsyncMock(), maxsize=mock_q_maxsize)
+        args = [arg1, bad, arg2]
         group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
         end_cb, cancel_cb = MagicMock(), MagicMock()
-        self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
-        # We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then
-        # acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3,
-        # at the end of the loop, we expect it be locked.
+        self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
+        # We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then
+        # acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
         self.assertTrue(semaphore.locked())
-        mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
+        mock_semaphore_cls.assert_called_once_with(n)
         mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
         mock__start_task.assert_has_awaits(2 * [
             call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
         ])
         mock_star_function.assert_has_calls([
             call(mock_func, arg1, arg_stars=stars),
-            call(mock_func, arg2, arg_stars=stars),
-            call(mock_func, bad, arg_stars=stars)
+            call(mock_func, bad, arg_stars=stars),
+            call(mock_func, arg2, arg_stars=stars)
         ])
+        mock_semaphore_cls.reset_mock()
+        mock__get_map_end_callback.reset_mock()
+        mock__start_task.reset_mock()
+        mock_star_function.reset_mock()
+        # With a CancelledError thrown while starting a task:
+        mock_semaphore_cls.return_value = semaphore = Semaphore(1)
+        mock_star_function.side_effect = CancelledError()
+        self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
+        self.assertFalse(semaphore.locked())
+        mock_semaphore_cls.assert_called_once_with(n)
+        mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
+        mock__start_task.assert_not_called()
+        mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
     @patch.object(pool, 'create_task')
-    @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
-    @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
-    @patch.object(pool, 'join_queue', new_callable=MagicMock)
-    @patch.object(pool, 'Queue')
+    @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
     @patch.object(pool, 'TaskGroupRegister')
     @patch.object(pool.BaseTaskPool, '_check_start')
-    async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
-                        mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
+    async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock,
                         mock_create_task: MagicMock):
         mock_group_reg = set_up_mock_group_register(mock_reg_cls)
-        mock_queue_cls.return_value = mock_q = MagicMock()
-        mock_join_queue.return_value = fake_join = object()
-        mock__queue_producer.return_value = fake_producer = object()
-        mock__queue_consumer.return_value = fake_consumer = object()
-        fake_task1, fake_task2 = object(), object()
-        mock_create_task.side_effect = [fake_task1, fake_task2]
-        group_name, group_size = 'onetwothree', 0
+        mock__arg_consumer.return_value = fake_consumer = object()
+        mock_create_task.return_value = fake_task = object()
+        group_name, n = 'onetwothree', 0
         func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
         end_cb, cancel_cb = MagicMock(), MagicMock()
         with self.assertRaises(ValueError):
-            await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
+            await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
         mock__check_start.assert_called_once_with(function=func)
         mock__check_start.reset_mock()
-        group_size = 1234
+        n = 1234
         self.task_pool._task_groups = {group_name: MagicMock()}
         with self.assertRaises(exceptions.InvalidGroupName):
-            await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
+            await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
         mock__check_start.assert_called_once_with(function=func)
         mock__check_start.reset_mock()
         self.task_pool._task_groups.clear()
-        self.task_pool._before_gathering = []
-        self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
+        self.assertIsNone(await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb))
         mock__check_start.assert_called_once_with(function=func)
         mock_reg_cls.assert_called_once_with()
         self.task_pool._task_groups[group_name] = mock_group_reg
         mock_group_reg.__aenter__.assert_awaited_once_with()
-        mock_queue_cls.assert_called_once_with(maxsize=group_size)
-        mock_join_queue.assert_called_once_with(mock_q)
-        self.assertListEqual([fake_join], self.task_pool._before_gathering)
-        mock__queue_producer.assert_called_once()
-        mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
-        mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
-        self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
+        mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars,
+                                                   end_callback=end_cb, cancel_callback=cancel_cb)
+        mock_create_task.assert_called_once_with(fake_consumer)
+        self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
         mock_group_reg.__aexit__.assert_awaited_once()
     @patch.object(pool.TaskPool, '_map')
@@ -682,18 +646,18 @@
     async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
         mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
         mock_func = MagicMock()
-        arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
+        arg_iter, num_concurrent, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
         end_cb, cancel_cb = MagicMock(), MagicMock()
-        output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
+        output = await self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb)
         self.assertEqual(group_name, output)
-        mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
+        mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, arg_iter, 0,
                                            end_callback=end_cb, cancel_callback=cancel_cb)
         mock__generate_group_name.assert_not_called()
         mock__map.reset_mock()
-        output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
+        output = await self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb)
         self.assertEqual(generated_name, output)
-        mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
+        mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0,
                                            end_callback=end_cb, cancel_callback=cancel_cb)
         mock__generate_group_name.assert_called_once_with('map', mock_func)
@@ -702,18 +666,18 @@ class TaskPoolTestCase(CommonTestCase):
     async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
         mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
         mock_func = MagicMock()
-        args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
+        args_iter, num_concurrent, group_name = ([FOO], [BAR]), 2, FOO + BAR
         end_cb, cancel_cb = MagicMock(), MagicMock()
-        output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
+        output = await self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb)
         self.assertEqual(group_name, output)
-        mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
+        mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, args_iter, 1,
                                            end_callback=end_cb, cancel_callback=cancel_cb)
         mock__generate_group_name.assert_not_called()
         mock__map.reset_mock()
-        output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
+        output = await self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb)
         self.assertEqual(generated_name, output)
-        mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
+        mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, args_iter, 1,
                                            end_callback=end_cb, cancel_callback=cancel_cb)
         mock__generate_group_name.assert_called_once_with('starmap', mock_func)
@@ -722,18 +686,18 @@ class TaskPoolTestCase(CommonTestCase):
     async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
         mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
         mock_func = MagicMock()
-        kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
+        kw_iter, num_concurrent, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
         end_cb, cancel_cb = MagicMock(), MagicMock()
-        output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
+        output = await self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb)
         self.assertEqual(group_name, output)
-        mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
+        mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, kw_iter, 2,
                                            end_callback=end_cb, cancel_callback=cancel_cb)
         mock__generate_group_name.assert_not_called()
         mock__map.reset_mock()
-        output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
+        output = await self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb)
         self.assertEqual(generated_name, output)
-        mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
+        mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2,
                                            end_callback=end_cb, cancel_callback=cancel_cb)
         mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
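Taken together, the three tests above pin down the only difference between `map`, `starmap`, and `doublestarmap`: the `stars` value (0, 1, or 2) forwarded to `_map`, i.e. how each element of the argument iterable is unpacked into a call. Below is a minimal usage sketch under that reading; the top-level import and the exact signatures are assumptions inferred from this diff, not verbatim documentation.

```python
import asyncio

from asyncio_taskpool import TaskPool  # assumed top-level import


async def show(a, b=0) -> None:
    print(a, b)
    await asyncio.sleep(0.1)


async def main() -> None:
    pool = TaskPool()
    # stars=0: each element is passed as one positional argument -> show(1), ...
    await pool.map(show, [1, 2, 3], num_concurrent=2)
    # stars=1: each element is unpacked with * -> show(1, 10), show(2, 20)
    await pool.starmap(show, [(1, 10), (2, 20)], num_concurrent=2)
    # stars=2: each element is unpacked with ** -> show(a=1, b=10)
    await pool.doublestarmap(show, [{'a': 1, 'b': 10}], num_concurrent=2)
    await pool.gather_and_close()


asyncio.run(main())
```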

View File

@@ -41,10 +41,9 @@ async def main() -> None:
     pool = SimpleTaskPool(work, args=(5,))  # initializes the pool; no work is being done yet
     await pool.start(3)  # launches work tasks 0, 1, and 2
     await asyncio.sleep(1.5)  # lets the tasks work for a bit
-    await pool.start()  # launches work task 3
+    await pool.start(1)  # launches work task 3
     await asyncio.sleep(1.5)  # lets the tasks work for a bit
     pool.stop(2)  # cancels tasks 3 and 2 (LIFO order)
-    pool.lock()  # required for the last line
     await pool.gather_and_close()  # awaits all tasks, then flushes the pool
@@ -135,11 +134,9 @@ async def main() -> None:
     # Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
     # only once there is room in the pool and no more than one other task of these new ones is running.
     args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
-    await pool.starmap(other_work, args_list, group_size=2)
+    await pool.starmap(other_work, args_list, num_concurrent=2)
     print("> Called `starmap`")
-    # Now we lock the pool, so that we can safely await all our tasks.
-    pool.lock()
-    # Finally, we block, until all tasks have ended.
+    # We block, until all tasks have ended.
     print("> Calling `gather_and_close`...")
     await pool.gather_and_close()
     print("> Done.")
@@ -199,7 +196,7 @@ Started TaskPool-0_Task-3
 > other_work with 15
 Ended TaskPool-0_Task-0
 Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
-Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start
+Started TaskPool-0_Task-4 <--- since `num_concurrent` is set to 2, Task-5 will not start
 > work with 190
 > work with 190
 > other_work with 16
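The log excerpt above shows what `num_concurrent=2` guarantees: Task-5 is not started merely because two pool slots freed up; a task of the same group has to end first. As a rough standalone illustration of that invariant (explicitly not the library's implementation; per the tests above, the library runs a single `_arg_consumer` meta-task per group):

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable


async def throttled_starmap(func: Callable[..., Awaitable[Any]],
                            args_iter: Iterable[tuple],
                            num_concurrent: int) -> None:
    iterator = iter(args_iter)

    async def worker() -> None:
        # Each worker synchronously claims the next unclaimed argument tuple;
        # within one event loop, no two workers can grab the same item.
        for args in iterator:
            await func(*args)

    # Exactly `num_concurrent` workers run at any moment until args run out.
    await asyncio.gather(*(worker() for _ in range(num_concurrent)))


async def demo(a: int, b: int) -> None:
    print('working on', a, b)
    await asyncio.sleep(0.1)


asyncio.run(throttled_starmap(demo, [(0, 10), (10, 20), (20, 30), (30, 40)], num_concurrent=2))
```

With `num_concurrent=2` and four argument tuples, the third tuple is only picked up once one of the first two calls returns, matching the `Started TaskPool-0_Task-4` line in the log.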

View File

@@ -75,7 +75,6 @@ async def main() -> None:
     control_server_task.cancel()
     # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
     # we can now safely cancel their tasks.
-    pool.lock()
     pool.stop_all()
     # Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
     # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
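Read as a whole, the teardown in this hunk is: cancel the control server, cancel the workers that are now idling on an empty queue, then await everything. A self-contained sketch of that shutdown sequence follows; the top-level `SimpleTaskPool` import and the `return_exceptions` pass-through to `asyncio.gather` are assumptions based on the surrounding examples, not verbatim library documentation.

```python
import asyncio

from asyncio_taskpool import SimpleTaskPool  # assumed top-level import


async def queue_worker(q: asyncio.Queue) -> None:
    while True:
        item = await q.get()  # workers end up parked here once the queue is empty
        print('processed', item)


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for i in range(10):
        q.put_nowait(i)
    pool = SimpleTaskPool(queue_worker, args=(q,))
    await pool.start(3)
    await asyncio.sleep(1)  # let the workers drain the queue
    pool.stop_all()  # safe: all workers are blocked on `q.get()`
    # Await the cancelled tasks' cleanup; exceptions are collected, not re-raised.
    await pool.gather_and_close(return_exceptions=True)


asyncio.run(main())
```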