diff --git a/README.md b/README.md
index 5cb6ffd..65e7751 100644
--- a/README.md
+++ b/README.md
@@ -15,11 +15,19 @@ If you need control over a task pool at runtime, you can launch an asynchronous
## Usage
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
+
```python
from asyncio_taskpool import SimpleTaskPool
+
...
+
+
async def work(foo, bar): ...
+
+
...
+
+
async def main():
pool = SimpleTaskPool(work, args=('xyz', 420))
await pool.start(5)
@@ -27,11 +35,11 @@ async def main():
pool.stop(3)
...
pool.lock()
- await pool.gather()
+ await pool.gather_and_close()
...
```
-Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
+Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
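+
+For illustration, here is a sketch of the group-based `TaskPool` API using a hypothetical `fetch` coroutine (this assumes `TaskPool` is exposed at the package level just like `SimpleTaskPool`):
+
+```python
+from asyncio_taskpool import TaskPool
+
+
+async def fetch(host): ...
+
+
+async def main():
+    pool = TaskPool()
+    # `map()` returns immediately; the tasks are started in the background by meta tasks.
+    group_name = await pool.map(fetch, ['foo', 'bar', 'baz'], group_size=2)
+    ...
+    await pool.cancel_group(group_name)  # also cancels the group's meta tasks
+    ...
+    pool.lock()
+    await pool.gather_and_close()
+```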
For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md).
diff --git a/setup.cfg b/setup.cfg
index 0e7eb76..fc73473 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
-version = 0.3.5
+version = 0.4.0
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks
diff --git a/src/asyncio_taskpool/__main__.py b/src/asyncio_taskpool/__main__.py
index b997fdb..9e8c59b 100644
--- a/src/asyncio_taskpool/__main__.py
+++ b/src/asyncio_taskpool/__main__.py
@@ -54,10 +54,10 @@ def parse_cli() -> Dict[str, Any]:
async def main():
kwargs = parse_cli()
if kwargs[CONN_TYPE] == UNIX:
- client = UnixControlClient(path=kwargs[SOCKET_PATH])
+ client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP:
# TODO: Implement the TCP client class
- client = UnixControlClient(path=kwargs[SOCKET_PATH])
+ client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
else:
print("Invalid connection type", file=sys.stderr)
sys.exit(2)
diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py
index c4f0869..9d2b749 100644
--- a/src/asyncio_taskpool/constants.py
+++ b/src/asyncio_taskpool/constants.py
@@ -21,6 +21,9 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool'
+DEFAULT_TASK_GROUP = ''
+DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
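+# Example of the format above: '2022-02-28_17-30-45' (used for auto-generated task group names)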
+
CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py
index dbc28df..6f0c7d5 100644
--- a/src/asyncio_taskpool/exceptions.py
+++ b/src/asyncio_taskpool/exceptions.py
@@ -23,6 +23,10 @@ class PoolException(Exception):
pass
+class PoolIsClosed(PoolException):
+ pass
+
+
class PoolIsLocked(PoolException):
pass
@@ -43,6 +47,10 @@ class InvalidTaskID(PoolException):
pass
+class InvalidGroupName(PoolException):
+ pass
+
+
class PoolStillUnlocked(PoolException):
pass
diff --git a/src/asyncio_taskpool/group_register.py b/src/asyncio_taskpool/group_register.py
new file mode 100644
index 0000000..81b7b81
--- /dev/null
+++ b/src/asyncio_taskpool/group_register.py
@@ -0,0 +1,75 @@
+__author__ = "Daniil Fajnberg"
+__copyright__ = "Copyright © 2022 Daniil Fajnberg"
+__license__ = """GNU LGPLv3.0
+
+This file is part of asyncio-taskpool.
+
+asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
+version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
+
+asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the GNU Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
+If not, see <https://www.gnu.org/licenses/>."""
+
+__doc__ = """
+This module contains the definition of the `TaskGroupRegister` class.
+"""
+
+
+from asyncio.locks import Lock
+from collections.abc import MutableSet
+from typing import Iterator, Set
+
+
+class TaskGroupRegister(MutableSet):
+ """
+ This class combines the interface of a regular `set` with that of the `asyncio.Lock`.
+
+ It serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for
+ preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of
+ tasks, as well as before starting a task within the group.
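+
+    A minimal usage sketch:
+
+        group_reg = TaskGroupRegister(1, 2, 3)
+        async with group_reg:  # acquires the group's lock
+            group_reg.add(4)
+            group_reg.discard(2)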
+ """
+
+ def __init__(self, *task_ids: int) -> None:
+ self._ids: Set[int] = set(task_ids)
+ self._lock = Lock()
+
+ def __contains__(self, task_id: int) -> bool:
+ """Abstract method for the `MutableSet` base class."""
+ return task_id in self._ids
+
+ def __iter__(self) -> Iterator[int]:
+ """Abstract method for the `MutableSet` base class."""
+ return iter(self._ids)
+
+ def __len__(self) -> int:
+ """Abstract method for the `MutableSet` base class."""
+ return len(self._ids)
+
+ def add(self, task_id: int) -> None:
+ """Abstract method for the `MutableSet` base class."""
+ self._ids.add(task_id)
+
+ def discard(self, task_id: int) -> None:
+ """Abstract method for the `MutableSet` base class."""
+ self._ids.discard(task_id)
+
+ async def acquire(self) -> bool:
+ """Wrapper around the lock's `acquire()` method."""
+ return await self._lock.acquire()
+
+ def release(self) -> None:
+ """Wrapper around the lock's `release()` method."""
+ self._lock.release()
+
+ async def __aenter__(self) -> None:
+ """Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
+ await self._lock.acquire()
+ return None
+
+ async def __aexit__(self, exc_type, exc, tb) -> None:
+ """Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
+ self._lock.release()
diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py
index ed3bfb0..9f0abfc 100644
--- a/src/asyncio_taskpool/pool.py
+++ b/src/asyncio_taskpool/pool.py
@@ -32,19 +32,22 @@ For further details about the classes check their respective docstrings.
import logging
-from asyncio import gather
from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
-from asyncio.locks import Event, Semaphore
-from asyncio.queues import Queue, QueueEmpty
-from asyncio.tasks import Task, create_task
-from functools import partial
+from asyncio.locks import Semaphore
+from asyncio.queues import QueueEmpty
+from asyncio.tasks import Task, create_task, gather
+from contextlib import suppress
+from datetime import datetime
from math import inf
-from typing import Any, Awaitable, Dict, Iterable, Iterator, List
+from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set
from . import exceptions
+from .constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
+from .group_register import TaskGroupRegister
from .helpers import execute_optional, star_function, join_queue
-from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
+from .queue_context import Queue
+from .types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
log = logging.getLogger(__name__)
@@ -62,19 +65,30 @@ class BaseTaskPool:
def __init__(self, pool_size: int = inf, name: str = None) -> None:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
- self._enough_room: Semaphore = Semaphore()
- self.pool_size = pool_size
+ # Initialize a counter for the total number of tasks started through the pool and one for the total number of
+ # tasks cancelled through the pool.
+ self._num_started: int = 0
+ self._num_cancellations: int = 0
+
+ # Initialize flags; immutably set the name.
self._locked: bool = False
- self._counter: int = 0
- self._running: Dict[int, Task] = {}
- self._cancelled: Dict[int, Task] = {}
- self._ended: Dict[int, Task] = {}
- self._num_cancelled: int = 0
- self._num_ended: int = 0
- self._idx: int = self._add_pool(self)
+ self._closed: bool = False
self._name: str = name
+
+ # The following three dictionaries are the actual containers of the tasks controlled by the pool.
+ self._tasks_running: Dict[int, Task] = {}
+ self._tasks_cancelled: Dict[int, Task] = {}
+ self._tasks_ended: Dict[int, Task] = {}
+
+ # These next three attributes act as synchronisation primitives necessary for managing the pool.
self._before_gathering: List[Awaitable] = []
- self._interrupt_flag: Event = Event()
+ self._enough_room: Semaphore = Semaphore()
+ self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
+
+        # Finish with method/function calls that add the pool to the internal list of pools, set its initial size,
+ # and issue a log message.
+ self._idx: int = self._add_pool(self)
+ self.pool_size = pool_size
log.debug("%s initialized", str(self))
def __str__(self) -> str:
@@ -125,34 +139,36 @@ class BaseTaskPool:
def num_running(self) -> int:
"""
Returns the number of tasks in the pool that are (at that moment) still running.
- At the moment a task's `end_callback` is fired, it is no longer considered to be running.
+
+ At the moment a task's `end_callback` or `cancel_callback` is fired, it is no longer considered running.
"""
- return len(self._running)
+ return len(self._tasks_running)
@property
- def num_cancelled(self) -> int:
+ def num_cancellations(self) -> int:
"""
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
- At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
+
+ At the moment a task's `cancel_callback` is fired, this counts as a cancellation, and the task is then
+ considered cancelled (instead of running) until its `end_callback` is fired.
"""
- return self._num_cancelled
+ return self._num_cancellations
@property
def num_ended(self) -> int:
"""
Returns the number of tasks started through the pool that have stopped running (up until that moment).
- At the moment a task's `end_callback` is fired, it is considered ended.
+
+ At the moment a task's `end_callback` is fired, it is considered ended and no longer running (or cancelled).
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
does it then actually end.
"""
- return self._num_ended
+ return len(self._tasks_ended)
@property
def num_finished(self) -> int:
- """
- Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
- """
- return self._num_ended - self._num_cancelled + len(self._cancelled)
+ """Returns the number of tasks in the pool that have finished running (without having been cancelled)."""
+ return len(self._tasks_ended) - self._num_cancellations + len(self._tasks_cancelled)
@property
def is_full(self) -> bool:
@@ -162,14 +178,61 @@ class BaseTaskPool:
"""
return self._enough_room.locked()
- # TODO: Consider adding task group names
+ def get_task_group_ids(self, group_name: str) -> Set[int]:
+ """
+ Returns the set of IDs of all tasks in the specified group.
+
+ Args:
+ group_name: Must be a name of a task group that exists within the pool.
+
+ Returns:
+ Set of integers representing the task IDs belonging to the specified group.
+
+ Raises:
+ `InvalidGroupName` if no task group named `group_name` exists in the pool.
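+
+        A usage sketch:
+
+            ids = pool.get_task_group_ids('my_group')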
+ """
+ try:
+ return set(self._task_groups[group_name])
+ except KeyError:
+ raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
+
+ def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None,
+ ignore_lock: bool = False) -> None:
+ """
+ Checks necessary conditions for starting a task (group) in the pool.
+
+ Either something that is expected to be a coroutine (i.e. awaitable) or something that is expected to be a
+ coroutine _function_ will be checked.
+
+ Args:
+ awaitable: If this is passed, `function` must be `None`.
+ function: If this is passed, `awaitable` must be `None`.
+ ignore_lock (optional): If `True`, a locked pool will produce no error here.
+
+ Raises:
+ `AssertionError` if both or neither of `awaitable` and `function` were passed.
+ `asyncio_taskpool.exceptions.PoolIsClosed` if the pool is closed.
+            `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine or `function` is not a coroutine function.
+ `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`.
+ """
+ assert (awaitable is None) != (function is None)
+ if awaitable and not iscoroutine(awaitable):
+ raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
+ if function and not iscoroutinefunction(function):
+ raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
+ if self._closed:
+ raise exceptions.PoolIsClosed("You must use another pool")
+ if self._locked and not ignore_lock:
+ raise exceptions.PoolIsLocked("Cannot start new tasks")
+
def _task_name(self, task_id: int) -> str:
"""Returns a standardized name for a task with a specific `task_id`."""
return f'{self}_Task-{task_id}'
- async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
+ async def _task_cancellation(self, task_id: int, custom_callback: CancelCB = None) -> None:
"""
Universal callback to be run upon any task in the pool being cancelled.
+
Required for keeping track of running/cancelled tasks and proper logging.
Args:
@@ -180,14 +243,15 @@ class BaseTaskPool:
It is run at the end of this function with the `task_id` as its only positional argument.
"""
log.debug("Cancelling %s ...", self._task_name(task_id))
- self._cancelled[task_id] = self._running.pop(task_id)
- self._num_cancelled += 1
+ self._tasks_cancelled[task_id] = self._tasks_running.pop(task_id)
+ self._num_cancellations += 1
log.debug("Cancelled %s", self._task_name(task_id))
await execute_optional(custom_callback, args=(task_id,))
- async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
+ async def _task_ending(self, task_id: int, custom_callback: EndCB = None) -> None:
"""
Universal callback to be run upon any task in the pool ending its work.
+
Required for keeping track of running/cancelled/ended tasks and proper logging.
Also releases room in the task pool for potentially waiting tasks.
@@ -199,19 +263,20 @@ class BaseTaskPool:
It is run at the end of this function with the `task_id` as its only positional argument.
"""
try:
- self._ended[task_id] = self._running.pop(task_id)
+ self._tasks_ended[task_id] = self._tasks_running.pop(task_id)
except KeyError:
- self._ended[task_id] = self._cancelled.pop(task_id)
- self._num_ended += 1
+ self._tasks_ended[task_id] = self._tasks_cancelled.pop(task_id)
self._enough_room.release()
log.info("Ended %s", self._task_name(task_id))
await execute_optional(custom_callback, args=(task_id,))
- async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
- cancel_callback: CancelCallbackT = None) -> Any:
+ async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCB = None,
+ cancel_callback: CancelCB = None) -> Any:
"""
- Universal wrapper around every task to be run in the pool.
- Returns/raises whatever the wrapped coroutine does.
+ Universal wrapper around every task run in the pool that returns/raises whatever the wrapped coroutine does.
+
+ Responsible for catching cancellation and awaiting the `_task_cancellation` callback, as well as for awaiting
+ the `_task_ending` callback, after the coroutine returns or raises an exception.
Args:
awaitable:
@@ -233,16 +298,19 @@ class BaseTaskPool:
finally:
await self._task_ending(task_id, custom_callback=end_callback)
- async def _start_task(self, awaitable: Awaitable, ignore_lock: bool = False, end_callback: EndCallbackT = None,
- cancel_callback: CancelCallbackT = None) -> int:
+ async def _start_task(self, awaitable: Awaitable, group_name: str = DEFAULT_TASK_GROUP, ignore_lock: bool = False,
+ end_callback: EndCB = None, cancel_callback: CancelCB = None) -> int:
"""
Starts a coroutine as a new task in the pool.
- This method blocks, **only if** the pool is full.
- Returns/raises whatever the wrapped coroutine does.
+
+ This method can block for a significant amount of time, **only if** the pool is full.
+ Otherwise it merely needs to acquire the `TaskGroupRegister` lock, which should never be held for a long time.
Args:
awaitable:
The actual coroutine to be run within the task pool.
+ group_name (optional):
+ Name of the task group to add the new task to; defaults to the `DEFAULT_TASK_GROUP` constant.
ignore_lock (optional):
If `True`, even if the pool is locked, the task will still be started.
end_callback (optional):
@@ -252,25 +320,20 @@ class BaseTaskPool:
A callback to execute after cancellation of the task.
It is run with the task's ID as its only positional argument.
- Raises:
- `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
- `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`.
+ Returns:
+ The ID of the newly started task.
"""
- if not iscoroutine(awaitable):
- raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
- if self._locked and not ignore_lock:
- raise exceptions.PoolIsLocked("Cannot start new tasks")
+ self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
await self._enough_room.acquire()
- task_id = self._counter
- self._counter += 1
- try:
- self._running[task_id] = create_task(
- self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
+ group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
+ async with group_reg:
+ task_id = self._num_started
+ self._num_started += 1
+ group_reg.add(task_id)
+ self._tasks_running[task_id] = create_task(
+ coro=self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
name=self._task_name(task_id)
)
- except Exception as e:
- self._enough_room.release()
- raise e
return task_id
def _get_running_task(self, task_id: int) -> Task:
@@ -286,11 +349,11 @@ class BaseTaskPool:
`asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool.
"""
try:
- return self._running[task_id]
+ return self._tasks_running[task_id]
except KeyError:
- if self._cancelled.get(task_id):
- raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
- if self._ended.get(task_id):
+ if self._tasks_cancelled.get(task_id):
+ raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has been cancelled")
+ if self._tasks_ended.get(task_id):
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
@@ -303,91 +366,111 @@ class BaseTaskPool:
- `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled.
- `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently).
- `InvalidTaskID` if any of the `task_ids` is not known to the pool.
- Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten.
+ Note that once a pool has been flushed (see below), IDs of tasks that have ended previously will be forgotten.
Args:
- task_ids:
- Arbitrary number of integers. Each must be an ID of a task still running within the pool.
- msg (optional):
- Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
+ task_ids: Arbitrary number of integers. Each must be an ID of a task still running within the pool.
+ msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
"""
tasks = [self._get_running_task(task_id) for task_id in task_ids]
for task in tasks:
task.cancel(msg=msg)
- def cancel_all(self, msg: str = None) -> None:
+ def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
+ """
+ Removes all tasks from the specified group and cancels them, if they are still running.
+
+ Args:
+ group_name: The name of the group of tasks that shall be cancelled.
+ group_reg: The task group register object containing the task IDs.
+        msg (optional): Passed to the `Task.cancel()` method of every task in the group.
+ """
+ while group_reg:
+ try:
+ self._tasks_running[group_reg.pop()].cancel(msg=msg)
+ except KeyError:
+ continue
+ log.debug("%s cancelled tasks from group %s", str(self), group_name)
+
+ async def cancel_group(self, group_name: str, msg: str = None) -> None:
+ """
+ Cancels an entire group of tasks. The task group is subsequently forgotten by the pool.
+
+ Args:
+ group_name: The name of the group of tasks that shall be cancelled.
+        msg (optional): Passed to the `Task.cancel()` method of every task in the group.
+
+ Raises:
+ `InvalidGroupName` if no task group named `group_name` exists in the pool.
+ """
+ log.debug("%s cancelling tasks in group %s", str(self), group_name)
+ try:
+ group_reg = self._task_groups.pop(group_name)
+ except KeyError:
+ raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
+ async with group_reg:
+ self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
+ log.debug("%s forgot task group %s", str(self), group_name)
+
+ async def cancel_all(self, msg: str = None) -> None:
"""
Cancels all tasks still running within the pool.
- Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
- This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
- than the number of arguments from `args_iter`.
- In this case, those already running will be cancelled, while the following will **never even start**.
-
Args:
- msg (optional):
- Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
+            msg (optional): Passed to the `Task.cancel()` method of every task in the pool.
"""
log.warning("%s cancelling all tasks!", str(self))
- self._interrupt_flag.set()
- for task in self._running.values():
- task.cancel(msg=msg)
+ while self._task_groups:
+ group_name, group_reg = self._task_groups.popitem()
+ async with group_reg:
+ self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
async def flush(self, return_exceptions: bool = False):
"""
- Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
+ Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks.
+
+ This method exists mainly to free up memory of unneeded `Task` objects.
+
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
callbacks registered for the tasks block.
Args:
return_exceptions (optional): Passed directly into `gather`.
"""
- results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
- self._ended.clear()
- self._cancelled.clear()
- if self._interrupt_flag.is_set():
- self._interrupt_flag.clear()
- return results
+ await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), return_exceptions=return_exceptions)
+ self._tasks_ended.clear()
+ self._tasks_cancelled.clear()
- async def gather(self, return_exceptions: bool = False):
+ async def gather_and_close(self, return_exceptions: bool = False):
"""
- Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
+ Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool.
The `lock()` method must have been called prior to this.
- Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks.
- This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller
- than the number of arguments from `args_iter`.
- In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
- blocking this method. Otherwise it will wait until they all have started.
-
- This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
- callbacks registered for a task blocks.
+    This method may block, if one of the tasks blocks while catching an `asyncio.CancelledError` or if any of the
+ callbacks registered for a task blocks for whatever reason.
Args:
return_exceptions (optional): Passed directly into `gather`.
Raises:
- `asyncio_taskpool.exceptions.PoolStillUnlocked` if the pool has not been locked yet.
+ `PoolStillUnlocked` if the pool has not been locked yet.
"""
if not self._locked:
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
await gather(*self._before_gathering)
- results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
- return_exceptions=return_exceptions)
- self._ended.clear()
- self._cancelled.clear()
- self._running.clear()
+ await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
+ return_exceptions=return_exceptions)
+ self._tasks_ended.clear()
+ self._tasks_cancelled.clear()
+ self._tasks_running.clear()
self._before_gathering.clear()
- if self._interrupt_flag.is_set():
- self._interrupt_flag.clear()
- return results
+ self._closed = True
class TaskPool(BaseTaskPool):
"""
- General task pool class.
- Attempts to somewhat emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
+ General task pool class. Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
Tasks in the pool can all belong to the same coroutine function,
@@ -399,41 +482,187 @@ class TaskPool(BaseTaskPool):
Adding tasks blocks **only if** the pool is full at that moment.
"""
- async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
+ _QUEUE_END_SENTINEL = object()
+
+ def __init__(self, pool_size: int = inf, name: str = None) -> None:
+ super().__init__(pool_size=pool_size, name=name)
+ # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
+        # meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
+ self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
+ self._meta_tasks_cancelled: Set[Task] = set()
+
+ def _cancel_group_meta_tasks(self, group_name: str) -> None:
+ """Cancels and forgets all meta tasks associated with the task group named `group_name`."""
+ try:
+ meta_tasks = self._group_meta_tasks_running.pop(group_name)
+ except KeyError:
+ return
+ for meta_task in meta_tasks:
+ meta_task.cancel()
+ self._meta_tasks_cancelled.update(meta_tasks)
+ log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
+
+ def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
+ self._cancel_group_meta_tasks(group_name)
+ super()._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
+
+ async def cancel_group(self, group_name: str, msg: str = None) -> None:
+ """
+ Cancels an entire group of tasks. The task group is subsequently forgotten by the pool.
+
+ If any methods such as `map()` launched meta tasks belonging to that group, these meta tasks are cancelled
+ before the actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will
+        **never even start**. In the case of `map()` this means that the `arg_iter` may be abandoned before it
+        was fully consumed (which may not even be possible, e.g. if it is an infinite iterator).
+
+ Args:
+ group_name: The name of the group of tasks (and meta tasks) that shall be cancelled.
+            msg (optional): Passed to the `Task.cancel()` method of every task in the group.
+
+ Raises:
+ `InvalidGroupName` if no task group named `group_name` exists in the pool.
+ """
+ await super().cancel_group(group_name=group_name, msg=msg)
+
+ async def cancel_all(self, msg: str = None) -> None:
+ """
+ Cancels all tasks still running within the pool. (This includes all meta tasks.)
+
+ If any methods such as `map()` launched meta tasks, these meta tasks are cancelled before the actual tasks are
+ cancelled. This means that any tasks "queued" to be started by a meta task will **never even start**. In the
+        case of `map()` this means that the `arg_iter` may be abandoned before it was fully consumed (which may not
+        even be possible, e.g. if it is an infinite iterator).
+
+ Args:
+            msg (optional): Passed to the `Task.cancel()` method of every task in the pool.
+ """
+ await super().cancel_all(msg=msg)
+
+ def _pop_ended_meta_tasks(self) -> Set[Task]:
+ """
+        Goes through all non-cancelled meta tasks, checks if they are done already, and returns those that are.
+
+ The internal `_group_meta_tasks_running` dictionary is updated accordingly, i.e. after this method returns, the
+ values (sets) only contain meta tasks that were not done yet. In addition, names of groups that no longer
+ have any running meta tasks associated with them are removed from the keys.
+ """
+ obsolete_keys, ended_meta_tasks = [], set()
+ for group_name in self._group_meta_tasks_running.keys():
+ still_running = set()
+ while self._group_meta_tasks_running[group_name]:
+ meta_task = self._group_meta_tasks_running[group_name].pop()
+ if meta_task.done():
+ ended_meta_tasks.add(meta_task)
+ else:
+ still_running.add(meta_task)
+ if still_running:
+ self._group_meta_tasks_running[group_name] = still_running
+ else:
+ obsolete_keys.append(group_name)
+        # If a group no longer has running meta tasks associated with it, we can remove its name from the dictionary.
+ for group_name in obsolete_keys:
+ del self._group_meta_tasks_running[group_name]
+ return ended_meta_tasks
+
+ async def flush(self, return_exceptions: bool = False):
+ """
+ Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks.
+
+ This method exists mainly to free up memory of unneeded `Task` objects. It also gets rid of unneeded meta tasks.
+
+    This method blocks, **only if** any of the tasks block while catching an `asyncio.CancelledError` or any of the
+ callbacks registered for the tasks block.
+
+ Args:
+ return_exceptions (optional): Passed directly into `gather`.
+ """
+ await super().flush(return_exceptions=return_exceptions)
+ with suppress(CancelledError):
+ await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
+ return_exceptions=return_exceptions)
+ self._meta_tasks_cancelled.clear()
+
+ async def gather_and_close(self, return_exceptions: bool = False):
+ """
+ Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool.
+
+ The `lock()` method must have been called prior to this.
+
+    Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
+    tasks launched by methods such as `map()`, which end by themselves only once the `arg_iter` is fully consumed,
+    which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
+    this, make sure to call `cancel_all()` first.
+
+    This method may also block, if one of the tasks blocks while catching an `asyncio.CancelledError` or if any of
+ the callbacks registered for a task blocks for whatever reason.
+
+ Args:
+ return_exceptions (optional): Passed directly into `gather`.
+
+ Raises:
+ `PoolStillUnlocked` if the pool has not been locked yet.
+ """
+ await super().gather_and_close(return_exceptions=return_exceptions)
+ not_cancelled_meta_tasks = set()
+ while self._group_meta_tasks_running:
+ _, meta_tasks = self._group_meta_tasks_running.popitem()
+ not_cancelled_meta_tasks.update(meta_tasks)
+ with suppress(CancelledError):
+ await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
+ self._meta_tasks_cancelled.clear()
+
+ @staticmethod
+ def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
+ """
+ Creates a task group identifier that includes the current datetime.
+
+ Args:
+ prefix: The start of the name; will be followed by an underscore.
+ coroutine_function: The function representing the task group.
+
+ Returns:
+ The constructed 'prefix_function_datetime' string to name a task group.
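+            For example: 'map_fetch_2022-02-28_17-30-45'.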
+ """
+ return f'{prefix}_{coroutine_function.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}'
+
+ async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
+ num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
- This method blocks, **only if** the pool is full.
+        This method blocks, **only if** the pool does not have enough room to accommodate `num` new tasks.
Args:
+ group_name:
+ Name of the task group to add the new task to.
func:
The coroutine function to be run as a task within the task pool.
args (optional):
The positional arguments to pass into the function call.
kwargs (optional):
The keyword-arguments to pass into the function call.
+ num (optional):
+ The number of tasks to spawn with the specified parameters.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
It is run with the task's ID as its only positional argument.
-
- Returns:
- The newly spawned task's ID within the pool.
"""
if kwargs is None:
kwargs = {}
- return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
+ await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
+ cancel_callback=cancel_callback) for _ in range(num)))
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]:
+ group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
Creates an arbitrary number of coroutines with the supplied arguments and runs them as new tasks in the pool.
- Each coroutine looks like `func(*args, **kwargs)`.
- This method blocks, **only if** there is not enough room in the pool for the desired number of new tasks.
+ Each coroutine looks like `func(*args, **kwargs)`. All the new tasks are added to the same task group.
+    This method blocks, **only if** the pool does not have enough room to accommodate `num` new tasks.
Args:
func:
@@ -444,6 +673,8 @@ class TaskPool(BaseTaskPool):
The keyword-arguments to pass into each function call.
num (optional):
The number of tasks to spawn with the specified parameters.
+ group_name (optional):
+ Name of the task group to add the new tasks to.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
@@ -452,176 +683,129 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument.
Returns:
- The newly spawned tasks' IDs within the pool as a list of integers.
+ The name of the task group that the newly spawned tasks have been added to.
Raises:
+ `PoolIsClosed` if the pool is closed.
`NotCoroutine` if `func` is not a coroutine function.
- `PoolIsLocked` if the pool has been locked already.
+ `PoolIsLocked` if the pool has been locked.
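+
+        A usage sketch (with a hypothetical coroutine function `work`):
+
+            group_name = await pool.apply(work, args=('x',), num=3)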
"""
- ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
- # TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
- assert isinstance(ids, list)
- return ids
+ self._check_start(function=func)
+ if group_name is None:
+ group_name = self._generate_group_name('apply', func)
+ group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
+ async with group_reg:
+ task = create_task(self._apply_num(group_name, func, args, kwargs, num, end_callback, cancel_callback))
+ await task
+ return group_name
- async def _queue_producer(self, q: Queue, args_iter: Iterator[Any]) -> None:
+ @classmethod
+ async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None:
"""
Keeps the arguments queue from `_map()` full as long as the iterator has elements.
- If the `_interrupt_flag` gets set, the loop ends prematurely.
+
+ Intended to be run as a meta task of a specific group.
Args:
- q:
- The queue of function arguments to consume for starting the next task.
- args_iter:
- The iterator of function arguments to put into the queue.
- """
- for arg in args_iter:
- if self._interrupt_flag.is_set():
- break
- await q.put(arg) # This blocks as long as the queue is full.
-
- async def _queue_consumer(self, q: Queue, first_batch_started: Event, func: CoroutineFunc, arg_stars: int = 0,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
- """
- Wrapper around the `_start_task()` taking the next element from the arguments queue set up in `_map()`.
- Partially constructs the `_queue_callback` function with the same arguments.
-
- Args:
- q:
- The queue of function arguments to consume for starting the next task.
- first_batch_started:
- The event flag to wait for, before launching the next consumer.
- It can only set by the `_map()` method, which happens after the first batch of task has been started.
- func:
- The coroutine function to use for spawning the tasks within the task pool.
- arg_stars (optional):
- Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
- end_callback (optional):
- The actual callback specified to execute after the task (and the next one) has ended.
- It is run with the task's ID as its only positional argument.
- cancel_callback (optional):
- The callback that was specified to execute after cancellation of the task (and the next one).
- It is run with the task's ID as its only positional argument.
+ arg_queue: The queue of function arguments to consume for starting a new task.
+ arg_iter: The iterator of function arguments to put into the queue.
+ group_name: Name of the task group associated with this producer.
"""
try:
- arg = q.get_nowait()
- except QueueEmpty:
- return
- try:
- await self._start_task(
- star_function(func, arg, arg_stars=arg_stars),
- ignore_lock=True,
- end_callback=partial(TaskPool._queue_callback, self, q=q, first_batch_started=first_batch_started,
- func=func, arg_stars=arg_stars, end_callback=end_callback,
- cancel_callback=cancel_callback),
- cancel_callback=cancel_callback
- )
+ for arg in arg_iter:
+ await arg_queue.put(arg) # This blocks as long as the queue is full.
+ except CancelledError:
+ # This means that no more tasks are supposed to be created from this `_map()` call;
+ # thus, we can immediately drain the entire queue and forget about the rest of the arguments.
+ log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
+ while True:
+ try:
+ arg_queue.get_nowait()
+ arg_queue.item_processed()
+ except QueueEmpty:
+ return
finally:
- q.task_done()
+ await arg_queue.put(cls._QUEUE_END_SENTINEL)
- async def _queue_callback(self, task_id: int, q: Queue, first_batch_started: Event, func: CoroutineFunc,
- arg_stars: int = 0, end_callback: EndCallbackT = None,
- cancel_callback: CancelCallbackT = None) -> None:
+ @staticmethod
+ def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB:
+ """Returns a wrapped `end_callback` for each `_queue_consumer()` task that will release the `map_semaphore`."""
+ async def release_callback(task_id: int) -> None:
+ map_semaphore.release()
+ await execute_optional(actual_end_callback, args=(task_id,))
+ return release_callback
+
+ async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0,
+ end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
- Wrapper around an end callback function passed into the `_map()` method.
- Triggers the next `_queue_consumer` with the same arguments.
+ Consumes arguments from the queue from `_map()` and keeps a limited number of tasks working on them.
+
+ The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before
+ a new task can be started, and which must be released when one of these tasks ends.
+
+ Intended to be run as a meta task of a specific group.
Args:
- task_id:
- The ID of the ending task.
- q:
- The queue of function arguments to consume for starting the next task.
- first_batch_started:
- The event flag to wait for, before launching the next consumer.
- It can only set by the `_map()` method, which happens after the first batch of task has been started.
+ arg_queue:
+ The queue of function arguments to consume for starting a new task.
+ group_name:
+ Name of the associated task group; passed into the `_start_task()` method.
func:
- The coroutine function to use for spawning the tasks within the task pool.
+ The coroutine function to use for spawning the new tasks within the task pool.
arg_stars (optional):
- Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
+ Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2.
end_callback (optional):
The actual callback specified to execute after the task (and the next one) has ended.
- It is run with the `task_id` as its only positional argument.
+ It is run with the task's ID as its only positional argument.
cancel_callback (optional):
The callback that was specified to execute after cancellation of the task (and the next one).
- It is run with the `task_id` as its only positional argument.
+ It is run with the task's ID as its only positional argument.
"""
- await first_batch_started.wait()
- await self._queue_consumer(q, first_batch_started, func, arg_stars,
- end_callback=end_callback, cancel_callback=cancel_callback)
- await execute_optional(end_callback, args=(task_id,))
+ map_semaphore = Semaphore(arg_queue.maxsize) # value determined by `group_size` in `_map()`
+ release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
+ while True:
+ # The following line blocks **only if** the number of running tasks spawned by this method has reached the
+ # specified maximum as determined in the `_map()` method.
+ await map_semaphore.acquire()
+ # We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
+ async with arg_queue as next_arg:
+ if next_arg is self._QUEUE_END_SENTINEL:
+ # The `_queue_producer()` either reached the last argument or was cancelled.
+ return
+ await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
+ ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
- def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
- """
- Helper function for `_map()`.
- Takes the iterable of function arguments `args_iter` and adds up to `group_size` to a new `asyncio.Queue`.
- The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.
-
- If the iterable contains less than `group_size` elements, nothing else happens; otherwise the `_queue_producer`
- is started as a separate task with the arguments queue and and iterator of the remaining arguments.
-
- Args:
- args_iter:
- The iterable of function arguments passed into `_map()` to use for creating the new tasks.
- num_tasks:
- The maximum number of the new tasks to run concurrently that was passed into `_map()`.
-
- Returns:
- The newly created and filled arguments queue for spawning new tasks.
- """
- # Setting the `maxsize` of the queue to `group_size` will ensure that no more than `group_size` tasks will run
- # concurrently because the size of the queue is what will determine the number of immediately started tasks in
- # the `_map()` method and each of those will only ever start (at most) one other task upon ending.
- args_queue = Queue(maxsize=num_tasks)
- self._before_gathering.append(join_queue(args_queue))
- args_iter = iter(args_iter)
- try:
- # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
- # tasks, which will be at most `group_size` (meaning the queue will be full).
- for i in range(num_tasks):
- args_queue.put_nowait(next(args_iter))
- except StopIteration:
- # If we get here, this means that the number of elements in the arguments iterator was less than the
- # specified `group_size`. Still, the number of tasks to start immediately will be the size of the queue.
- # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
- pass
- else:
- # There may be more elements in the arguments iterator, so we need the `_queue_producer`.
- # It will have exclusive access to the `args_iter` from now on.
- # Since the queue is full already, it will wait until one of the tasks in the first batch ends,
- # before putting the next item in it.
- create_task(self._queue_producer(args_queue, args_iter))
- return args_queue
-
- async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, group_size: int = 1,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
+ async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
+ end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
- Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `args_iter`.
+ Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
+ All the new tasks are added to the same task group.
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
- any given moment in time. Assuming the number of elements produced by `args_iter` is greater than `group_size`,
- this method will block **only** until the first `group_size` tasks have been **started**, before returning.
- (If the number of elements from `args_iter` is smaller than `group_size`, this method will return as soon as
- all of them have been started.)
+ any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
+ (assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
+ of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
+ concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
- As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
- the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
- limit, this ensures that the number of tasks running concurrently as a result of this method call is always
- equal to `group_size` (except for when `args_iter` is exhausted of course).
-
- Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
-
- This method sets up an internal arguments queue which is continuously filled while consuming the `args_iter`.
+ This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
+ Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
+        aforementioned queue), it **never blocks**. However, the fact that it returns immediately does not mean that
+        any task was started or that any number of tasks will start soon, as that is determined solely by the
+        `pool_size` and the `group_size`.
Args:
+ group_name:
+ Name of the task group to add the new tasks to. It must be a name that doesn't exist yet.
+ group_size:
+            The maximum number of new tasks spawned by this method to run concurrently.
func:
The coroutine function to use for spawning the new tasks within the task pool.
- args_iter:
+ arg_iter:
The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
- arg_stars (optional):
+ arg_stars:
+            Whether or not to unpack an element from `arg_iter` using stars; must be 0, 1, or 2.
- group_size (optional):
- The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
@@ -630,44 +814,46 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument.
Raises:
- `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked.
+ `ValueError` if `group_size` is less than 1.
+ `asyncio_taskpool.exceptions.InvalidGroupName` if a group named `group_name` exists in the pool.
"""
- if not self._locked:
- raise exceptions.PoolIsLocked("Cannot start new tasks")
- args_queue = self._set_up_args_queue(args_iter, group_size)
- # We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
- # `_queue_callback` triggered by one or more of them.
- # This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more
- # than one element, and the pool remains full until after the first task of the first batch ends. Then the
- # callback might trigger the next `_queue_consumer` before this method can, which will keep it blocked.
- first_batch_started = Event()
- for _ in range(args_queue.qsize()):
- # This is where blocking can occur, if the pool is full.
- await self._queue_consumer(args_queue, first_batch_started, func,
- arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
- # Now the callbacks can immediately trigger more tasks.
- first_batch_started.set()
+ self._check_start(function=func)
+ if group_size < 1:
+ raise ValueError(f"Group size must be a positive integer.")
+        if group_name in self._task_groups:
+ raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
+ self._task_groups[group_name] = group_reg = TaskGroupRegister()
+ async with group_reg:
+            # Set up the internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter`
+            # by the `_queue_producer()`; that way an argument is only pulled from `arg_iter` once there is room in
+            # the queue.
+ arg_queue = Queue(maxsize=group_size)
+ self._before_gathering.append(join_queue(arg_queue))
+ meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
+ # Start the producer and consumer meta tasks.
+ meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
+ meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars,
+ end_callback, cancel_callback)))
- async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
+ async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None,
+ end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
+ All the new tasks are added to the same task group.
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
- any given moment in time. Assuming the number of elements produced by `arg_iter` is greater than `group_size`,
- this method will block **only** until the first `group_size` tasks have been **started**, before returning.
- (If the number of elements from `arg_iter` is smaller than `group_size`, this method will return as soon as
- all of them have been started.)
+ any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
+ (assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
+ of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
+ concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
- As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in
- the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a
- limit, this ensures that the number of tasks running concurrently as a result of this method call is always
- equal to `group_size` (except for when `arg_iter` is exhausted of course).
-
- Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
+ This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
+ Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
+    aforementioned queue), it **never blocks**. However, the fact that it returns immediately does not mean that any
+    task was started or that any number of tasks will start soon, as that is determined solely by the `pool_size`
+    and the `group_size`.
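+
+    A usage sketch (with a hypothetical coroutine function `fetch`):
+
+        group_name = await pool.map(fetch, ['x', 'y', 'z'], group_size=2)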
Args:
func:
@@ -676,6 +862,8 @@ class TaskPool(BaseTaskPool):
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
group_size (optional):
The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
+ group_name (optional):
+ Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
@@ -683,32 +871,48 @@ class TaskPool(BaseTaskPool):
A callback to execute after cancellation of a task.
It is run with the task's ID as its only positional argument.
+ Returns:
+ The name of the task group that the newly spawned tasks will be added to.
+
Raises:
- `PoolIsLocked` if the pool has been locked.
+ `PoolIsClosed` if the pool is closed.
`NotCoroutine` if `func` is not a coroutine function.
+ `PoolIsLocked` if the pool has been locked.
+ `ValueError` if `group_size` is less than 1.
+ `InvalidGroupName` if a group named `group_name` exists in the pool.
"""
- await self._map(func, arg_iter, arg_stars=0, group_size=group_size,
+ if group_name is None:
+ group_name = self._generate_group_name('map', func)
+ await self._map(group_name, group_size, func, arg_iter, 0,
end_callback=end_callback, cancel_callback=cancel_callback)
+ return group_name
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
+ group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as
positional arguments to the function.
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
"""
- await self._map(func, args_iter, arg_stars=1, group_size=group_size,
+ if group_name is None:
+ group_name = self._generate_group_name('starmap', func)
+ await self._map(group_name, group_size, func, args_iter, 1,
end_callback=end_callback, cancel_callback=cancel_callback)
+ return group_name
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
+ group_name: str = None, end_callback: EndCB = None,
+ cancel_callback: CancelCB = None) -> str:
"""
Like `map()` except that the elements of `kwargs_iter` are expected to be mappings themselves, to be unpacked as
keyword arguments to the function.
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
"""
- await self._map(func, kwargs_iter, arg_stars=2, group_size=group_size,
+ if group_name is None:
+ group_name = self._generate_group_name('doublestarmap', func)
+ await self._map(group_name, group_size, func, kwargs_iter, 2,
end_callback=end_callback, cancel_callback=cancel_callback)
+ return group_name
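
Since all three methods now return their group's name, a compact usage sketch may be clearer than prose. The worker coroutine, the argument values, and the bare `TaskPool()` construction are hypothetical; the import assumes `TaskPool` is exported from the package root like `SimpleTaskPool`:

```python
from asyncio_taskpool import TaskPool  # assumed to be exported like SimpleTaskPool


async def work(a, b=0): ...


async def main() -> None:
    pool = TaskPool()
    # `map`: each element is the single positional argument, i.e. work(1), work(2), work(3)
    map_group = await pool.map(work, [1, 2, 3])
    # `starmap`: each element is unpacked positionally, i.e. work(1, 2), then work(3, 4)
    star_group = await pool.starmap(work, [(1, 2), (3, 4)])
    # `doublestarmap`: each element is unpacked as keyword arguments, i.e. work(a=1, b=2)
    double_group = await pool.doublestarmap(work, [{'a': 1, 'b': 2}])
    print(map_group, star_group, double_group)  # auto-generated group names
    pool.lock()
    await pool.gather_and_close()
```
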
class SimpleTaskPool(BaseTaskPool):
@@ -729,7 +933,7 @@ class SimpleTaskPool(BaseTaskPool):
"""
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
- end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
+ end_callback: EndCB = None, cancel_callback: CancelCB = None,
pool_size: int = inf, name: str = None) -> None:
"""
@@ -756,8 +960,8 @@ class SimpleTaskPool(BaseTaskPool):
self._func: CoroutineFunc = func
self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
- self._end_callback: EndCallbackT = end_callback
- self._cancel_callback: CancelCallbackT = cancel_callback
+ self._end_callback: EndCB = end_callback
+ self._cancel_callback: CancelCB = cancel_callback
super().__init__(pool_size=pool_size, name=name)
@property
@@ -784,7 +988,7 @@ class SimpleTaskPool(BaseTaskPool):
If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled.
"""
ids = []
- for i, task_id in enumerate(reversed(self._running)):
+ for i, task_id in enumerate(reversed(self._tasks_running)):
if i >= num:
break # We got the desired number of task IDs, there may well be more tasks left to keep running
ids.append(task_id)
diff --git a/src/asyncio_taskpool/queue_context.py b/src/asyncio_taskpool/queue_context.py
new file mode 100644
index 0000000..29959bf
--- /dev/null
+++ b/src/asyncio_taskpool/queue_context.py
@@ -0,0 +1,58 @@
+__author__ = "Daniil Fajnberg"
+__copyright__ = "Copyright © 2022 Daniil Fajnberg"
+__license__ = """GNU LGPLv3.0
+
+This file is part of asyncio-taskpool.
+
+asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
+version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
+
+asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the GNU Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
+If not, see <https://www.gnu.org/licenses/>."""
+
+__doc__ = """
+This module contains the definition of an `asyncio.Queue` subclass.
+"""
+
+
+from asyncio.queues import Queue as _Queue
+from typing import Any
+
+
+class Queue(_Queue):
+ """This just adds a little syntactic sugar to the `asyncio.Queue`."""
+
+ def item_processed(self) -> None:
+ """
+ Does exactly the same as `task_done()`.
+
+ This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing,
+ invites confusion, and immensely reduces readability (in the context of this library). And readability counts.
+ """
+ self.task_done()
+
+ async def __aenter__(self) -> Any:
+ """
+ Implements an asynchronous context manager for the queue.
+
+ Upon entering, `get()` is awaited and subsequently whatever came out of the queue is returned.
+ It allows writing code this way:
+ >>> queue = Queue()
+ >>> ...
+ >>> async with queue as item:
+ ...     ...
+ """
+ return await self.get()
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+ """
+ Implements an asynchronous context manager for the queue.
+
+ Upon exiting, `item_processed()` is called. This is why this context manager may not always be what you want,
+ but in some situations it makes the code much cleaner.
+ """
+ self.item_processed()
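
A small, self-contained sketch of how the new context manager reads in practice; the `None` shutdown sentinel is an assumption for this example, not part of the class:

```python
import asyncio

from asyncio_taskpool.queue_context import Queue


async def consume(queue: Queue) -> None:
    while True:
        # `__aenter__` awaits `queue.get()`; `__aexit__` calls `item_processed()`,
        # even when we return from inside the block.
        async with queue as item:
            if item is None:  # assumed shutdown sentinel
                return
            print("processing", item)


async def main() -> None:
    queue = Queue()
    for item in ("a", "b", None):
        queue.put_nowait(item)
    await consume(queue)
    await queue.join()  # returns immediately, since every item was marked as processed


asyncio.run(main())
```

The upside of the context manager is that a `get()` can never be left without its matching `item_processed()` call; the downside, as the `__aexit__` docstring concedes, is that this pairing becomes mandatory.
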
diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py
index f6177e7..b8a5194 100644
--- a/src/asyncio_taskpool/types.py
+++ b/src/asyncio_taskpool/types.py
@@ -32,8 +32,8 @@ KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
-EndCallbackT = Callable
-CancelCallbackT = Callable
+EndCB = Callable
+CancelCB = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
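
The renamed aliases remain bare `Callable`s, so the expected signature is documented only in prose: per the pool docstrings, both callbacks are invoked with the ended/cancelled task's ID as their only positional argument. A hypothetical pair (whether plain functions, coroutine functions, or both are accepted is an assumption here, suggested by the `execute_optional` helper appearing in the tests):

```python
def on_cancel(task_id: int) -> None:
    print(f"Task {task_id} was cancelled")


async def on_end(task_id: int) -> None:
    print(f"Task {task_id} has ended")
```
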
diff --git a/tests/test_pool.py b/tests/test_pool.py
index bbbc64f..aef5fc0 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -18,19 +18,20 @@ __doc__ = """
Unittests for the `asyncio_taskpool.pool` module.
"""
-
-import asyncio
from asyncio.exceptions import CancelledError
-from asyncio.queues import Queue
+from asyncio.locks import Semaphore
+from asyncio.queues import QueueEmpty
+from datetime import datetime
from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type
from asyncio_taskpool import pool, exceptions
+from asyncio_taskpool.constants import DATETIME_FORMAT
-EMPTY_LIST, EMPTY_DICT = [], {}
-FOO, BAR = 'foo', 'bar'
+EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
+FOO, BAR, BAZ = 'foo', 'bar', 'baz'
class TestException(Exception):
@@ -45,19 +46,12 @@ class CommonTestCase(IsolatedAsyncioTestCase):
task_pool: pool.BaseTaskPool
log_lvl: int
- @classmethod
- def setUpClass(cls) -> None:
- cls.log_lvl = pool.log.level
- pool.log.setLevel(999)
-
- @classmethod
- def tearDownClass(cls) -> None:
- pool.log.setLevel(cls.log_lvl)
-
def get_task_pool_init_params(self) -> dict:
return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
def setUp(self) -> None:
+ self.log_lvl = pool.log.level
+ pool.log.setLevel(999)
self._pools = self.TEST_CLASS._pools
# These three methods are called during initialization, so we mock them by default during setup:
self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
@@ -76,6 +70,7 @@ class CommonTestCase(IsolatedAsyncioTestCase):
self._add_pool_patcher.stop()
self.pool_size_patcher.stop()
self.dunder_str_patcher.stop()
+ pool.log.setLevel(self.log_lvl)
class BaseTaskPoolTestCase(CommonTestCase):
@@ -88,19 +83,23 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self):
- self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
+ self.assertEqual(0, self.task_pool._num_started)
+ self.assertEqual(0, self.task_pool._num_cancellations)
+
self.assertFalse(self.task_pool._locked)
- self.assertEqual(0, self.task_pool._counter)
- self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
- self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
- self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
- self.assertEqual(0, self.task_pool._num_cancelled)
- self.assertEqual(0, self.task_pool._num_ended)
- self.assertEqual(self.mock_idx, self.task_pool._idx)
+ self.assertFalse(self.task_pool._closed)
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
+
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
+
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
- self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
- self.assertFalse(self.task_pool._interrupt_flag.is_set())
+ self.assertIsInstance(self.task_pool._enough_room, Semaphore)
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
+
+ self.assertEqual(self.mock_idx, self.task_pool._idx)
+
self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with()
@@ -143,26 +142,56 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertFalse(self.task_pool._locked)
def test_num_running(self):
- self.task_pool._running = ['foo', 'bar', 'baz']
+ self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_running)
- def test_num_cancelled(self):
- self.task_pool._num_cancelled = 3
- self.assertEqual(3, self.task_pool.num_cancelled)
+ def test_num_cancellations(self):
+ self.task_pool._num_cancellations = 3
+ self.assertEqual(3, self.task_pool.num_cancellations)
def test_num_ended(self):
- self.task_pool._num_ended = 3
+ self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_ended)
def test_num_finished(self):
- self.task_pool._num_cancelled = cancelled = 69
- self.task_pool._num_ended = ended = 420
- self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
- self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)
+ self.task_pool._num_cancellations = num_cancellations = 69
+ num_ended = 420
+ self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)}
+ self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ}
+ self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished)
def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
+ def test_get_task_group_ids(self):
+ group_name, ids = 'abcdef', [1, 2, 3]
+ self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
+ self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name))
+ with self.assertRaises(exceptions.InvalidGroupName):
+ self.task_pool.get_task_group_ids('something else')
+
+ async def test__check_start(self):
+ self.task_pool._closed = True
+ mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
+ try:
+ with self.assertRaises(AssertionError):
+ self.task_pool._check_start(awaitable=None, function=None)
+ with self.assertRaises(AssertionError):
+ self.task_pool._check_start(awaitable=mock_coroutine, function=mock_coroutine_function)
+ with self.assertRaises(exceptions.NotCoroutine):
+ self.task_pool._check_start(awaitable=mock_coroutine_function, function=None)
+ with self.assertRaises(exceptions.NotCoroutine):
+ self.task_pool._check_start(awaitable=None, function=mock_coroutine)
+ with self.assertRaises(exceptions.PoolIsClosed):
+ self.task_pool._check_start(awaitable=mock_coroutine, function=None)
+ self.task_pool._closed = False
+ self.task_pool._locked = True
+ with self.assertRaises(exceptions.PoolIsLocked):
+ self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
+ self.assertIsNone(self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=True))
+ finally:
+ await mock_coroutine
+
def test__task_name(self):
i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@@ -171,12 +200,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
- self.task_pool._num_cancelled = cancelled = 3
- self.task_pool._running[task_id] = mock_task
+ self.task_pool._num_cancellations = cancelled = 3
+ self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
- self.assertNotIn(task_id, self.task_pool._running)
- self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
- self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
+ self.assertNotIn(task_id, self.task_pool._tasks_running)
+ self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id])
+ self.assertEqual(cancelled + 1, self.task_pool._num_cancellations)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@@ -184,15 +213,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
- self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
# End running task:
- self.task_pool._running[task_id] = mock_task
+ self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
- self.assertNotIn(task_id, self.task_pool._running)
- self.assertEqual(mock_task, self.task_pool._ended[task_id])
- self.assertEqual(ended + 1, self.task_pool._num_ended)
+ self.assertNotIn(task_id, self.task_pool._tasks_running)
+ self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@@ -200,11 +227,10 @@ class BaseTaskPoolTestCase(CommonTestCase):
mock_execute_optional.reset_mock()
# End cancelled task:
- self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
+ self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_ended.pop(task_id)
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
- self.assertNotIn(task_id, self.task_pool._cancelled)
- self.assertEqual(mock_task, self.task_pool._ended[task_id])
- self.assertEqual(ended + 2, self.task_pool._num_ended)
+ self.assertNotIn(task_id, self.task_pool._tasks_cancelled)
+ self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@@ -246,92 +272,52 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
- async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock,
- mock_create_task: MagicMock):
- def reset_mocks() -> None:
- mock__task_name.reset_mock()
- mock__task_wrapper.reset_mock()
- mock_create_task.reset_mock()
-
+ @patch.object(pool, 'TaskGroupRegister')
+ @patch.object(pool.BaseTaskPool, '_check_start')
+ async def test__start_task(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__task_name: MagicMock,
+ mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
+ mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
- mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
- self.task_pool._counter = count = 123
+ mock_coroutine, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
+ self.task_pool._num_started = count = 123
self.task_pool._enough_room._value = room = 123
-
- def check_nothing_changed() -> None:
- self.assertEqual(count, self.task_pool._counter)
- self.assertNotIn(count, self.task_pool._running)
- self.assertEqual(room, self.task_pool._enough_room._value)
- mock__task_name.assert_not_called()
- mock__task_wrapper.assert_not_called()
- mock_create_task.assert_not_called()
- reset_mocks()
-
- with self.assertRaises(exceptions.NotCoroutine):
- await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
- check_nothing_changed()
-
- self.task_pool._locked = True
- ignore_closed = False
- mock_awaitable = mock_coroutine()
- with self.assertRaises(exceptions.PoolIsLocked):
- await self.task_pool._start_task(mock_awaitable, ignore_closed,
- end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
- await mock_awaitable
- check_nothing_changed()
-
- ignore_closed = True
- mock_awaitable = mock_coroutine()
- output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
+ group_name, ignore_lock = 'testgroup', True
+ output = await self.task_pool._start_task(mock_coroutine, group_name=group_name, ignore_lock=ignore_lock,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
- await mock_awaitable
self.assertEqual(count, output)
- self.assertEqual(count + 1, self.task_pool._counter)
- self.assertEqual(mock_task, self.task_pool._running[count])
+ mock__check_start.assert_called_once_with(awaitable=mock_coroutine, ignore_lock=ignore_lock)
self.assertEqual(room - 1, self.task_pool._enough_room._value)
+ self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name])
+ mock_reg_cls.assert_called_once_with()
+ mock_group_reg.__aenter__.assert_awaited_once_with()
+ mock_group_reg.add.assert_called_once_with(count)
mock__task_name.assert_called_once_with(count)
- mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
- mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
- reset_mocks()
- self.task_pool._counter = count
- self.task_pool._enough_room._value = room
- del self.task_pool._running[count]
-
- mock_awaitable = mock_coroutine()
- mock_create_task.side_effect = test_exception = TestException()
- with self.assertRaises(TestException) as e:
- await self.task_pool._start_task(mock_awaitable, ignore_closed,
- end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
- self.assertEqual(test_exception, e)
- await mock_awaitable
- self.assertEqual(count + 1, self.task_pool._counter)
- self.assertNotIn(count, self.task_pool._running)
- self.assertEqual(room, self.task_pool._enough_room._value)
- mock__task_name.assert_called_once_with(count)
- mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
- mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
+ mock__task_wrapper.assert_called_once_with(mock_coroutine, count, mock_end_cb, mock_cancel_cb)
+ mock_create_task.assert_called_once_with(coro=mock_wrapped, name=FOO)
+ self.assertEqual(mock_task, self.task_pool._tasks_running[count])
+ mock_group_reg.__aexit__.assert_awaited_once()
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
def test__get_running_task(self, mock__task_name: MagicMock):
task_id, mock_task = 555, MagicMock()
- self.task_pool._running[task_id] = mock_task
+ self.task_pool._tasks_running[task_id] = mock_task
output = self.task_pool._get_running_task(task_id)
self.assertEqual(mock_task, output)
- self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
+ self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_running.pop(task_id)
with self.assertRaises(exceptions.AlreadyCancelled):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
- self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
+ self.task_pool._tasks_ended[task_id] = self.task_pool._tasks_cancelled.pop(task_id)
with self.assertRaises(exceptions.TaskEnded):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
- del self.task_pool._ended[task_id]
+ del self.task_pool._tasks_ended[task_id]
with self.assertRaises(exceptions.InvalidTaskID):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@@ -344,263 +330,416 @@ class BaseTaskPoolTestCase(CommonTestCase):
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
- def test_cancel_all(self):
- mock_task1, mock_task2 = MagicMock(), MagicMock()
- self.task_pool._running = {1: mock_task1, 2: mock_task2}
- assert not self.task_pool._interrupt_flag.is_set()
- self.assertIsNone(self.task_pool.cancel_all(FOO))
- self.assertTrue(self.task_pool._interrupt_flag.is_set())
- mock_task1.cancel.assert_called_once_with(msg=FOO)
- mock_task2.cancel.assert_called_once_with(msg=FOO)
+ def test__cancel_and_remove_all_from_group(self):
+ task_id = 555
+ mock_cancel = MagicMock()
+ self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
+
+ class MockRegister(set, MagicMock):
+ pass
+ self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
+ mock_cancel.assert_called_once_with(msg=FOO)
+
+ @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
+ async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
+ mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
+ mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
+ self.task_pool._task_groups[FOO] = mock_group_reg
+ with self.assertRaises(exceptions.InvalidGroupName):
+ await self.task_pool.cancel_group(BAR)
+ mock__cancel_and_remove_all_from_group.assert_not_called()
+ mock_grp_aenter.assert_not_called()
+ mock_grp_aexit.assert_not_called()
+ self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR))
+ mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
+ mock_grp_aenter.assert_awaited_once_with()
+ mock_grp_aexit.assert_awaited_once()
+
+ @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
+ async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
+ mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
+ mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
+ self.task_pool._task_groups[BAR] = mock_group_reg
+ self.assertIsNone(await self.task_pool.cancel_all(FOO))
+ mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO)
+ mock_grp_aenter.assert_awaited_once_with()
+ mock_grp_aexit.assert_awaited_once()
async def test_flush(self):
- test_exception = TestException()
- mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
- self.task_pool._ended = {123: mock_ended_func()}
- self.task_pool._cancelled = {456: mock_cancelled_func()}
- self.task_pool._interrupt_flag.set()
- output = await self.task_pool.flush(return_exceptions=True)
- self.assertListEqual([FOO, test_exception], output)
- self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
- self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
- self.assertFalse(self.task_pool._interrupt_flag.is_set())
+ mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
+ self.task_pool._tasks_ended = {123: mock_ended_func()}
+ self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
+ self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
+ mock_ended_func.assert_awaited_once_with()
+ mock_cancelled_func.assert_awaited_once_with()
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
- self.task_pool._ended = {123: mock_ended_func()}
- self.task_pool._cancelled = {456: mock_cancelled_func()}
- output = await self.task_pool.flush(return_exceptions=True)
- self.assertListEqual([FOO, test_exception], output)
- self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
- self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
+ async def test_gather_and_close(self):
+ mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
+ mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
+ self.task_pool._before_gathering = before_gather = [mock_before_gather()]
+ self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
+ self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
+ self.task_pool._tasks_running = running = {789: mock_running_func()}
- async def test_gather(self):
- test_exception = TestException()
- mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
- mock_running_func = AsyncMock(return_value=BAR)
- mock_queue_join = AsyncMock()
- self.task_pool._before_gathering = before_gather = [mock_queue_join()]
- self.task_pool._ended = ended = {123: mock_ended_func()}
- self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
- self.task_pool._running = running = {789: mock_running_func()}
- self.task_pool._interrupt_flag.set()
-
- assert not self.task_pool._locked
with self.assertRaises(exceptions.PoolStillUnlocked):
- await self.task_pool.gather()
- self.assertDictEqual(self.task_pool._ended, ended)
- self.assertDictEqual(self.task_pool._cancelled, cancelled)
- self.assertDictEqual(self.task_pool._running, running)
- self.assertListEqual(self.task_pool._before_gathering, before_gather)
- self.assertTrue(self.task_pool._interrupt_flag.is_set())
+ await self.task_pool.gather_and_close()
+ self.assertDictEqual(ended, self.task_pool._tasks_ended)
+ self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
+ self.assertDictEqual(running, self.task_pool._tasks_running)
+ self.assertListEqual(before_gather, self.task_pool._before_gathering)
+ self.assertFalse(self.task_pool._closed)
self.task_pool._locked = True
-
- def check_assertions(output) -> None:
- self.assertListEqual([FOO, test_exception, BAR], output)
- self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
- self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
- self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
- self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
- self.assertFalse(self.task_pool._interrupt_flag.is_set())
-
- check_assertions(await self.task_pool.gather(return_exceptions=True))
-
- self.task_pool._before_gathering = [mock_queue_join()]
- self.task_pool._ended = {123: mock_ended_func()}
- self.task_pool._cancelled = {456: mock_cancelled_func()}
- self.task_pool._running = {789: mock_running_func()}
- check_assertions(await self.task_pool.gather(return_exceptions=True))
+ self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
+ mock_before_gather.assert_awaited_once_with()
+ mock_ended_func.assert_awaited_once_with()
+ mock_cancelled_func.assert_awaited_once_with()
+ mock_running_func.assert_awaited_once_with()
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
+ self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
+ self.assertTrue(self.task_pool._closed)
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
- @patch.object(pool.TaskPool, '_start_task')
- async def test__apply_one(self, mock__start_task: AsyncMock):
- mock__start_task.return_value = expected_output = 12345
- mock_awaitable = MagicMock()
- mock_func = MagicMock(return_value=mock_awaitable)
- args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
- end_cb, cancel_cb = MagicMock(), MagicMock()
- output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
+ def setUp(self) -> None:
+ self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
+ self.base_class_init = self.base_class_init_patcher.start()
+ super().setUp()
+
+ def tearDown(self) -> None:
+ self.base_class_init_patcher.stop()
+ super().tearDown()
+
+ def test_init(self):
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
+ self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
+
+ def test__cancel_group_meta_tasks(self):
+ mock_task1, mock_task2 = MagicMock(), MagicMock()
+ self.task_pool._group_meta_tasks_running[BAR] = {mock_task1, mock_task2}
+ self.assertIsNone(self.task_pool._cancel_group_meta_tasks(FOO))
+ self.assertDictEqual({BAR: {mock_task1, mock_task2}}, self.task_pool._group_meta_tasks_running)
+ self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
+ mock_task1.cancel.assert_not_called()
+ mock_task2.cancel.assert_not_called()
+
+ self.assertIsNone(self.task_pool._cancel_group_meta_tasks(BAR))
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
+ self.assertSetEqual({mock_task1, mock_task2}, self.task_pool._meta_tasks_cancelled)
+ mock_task1.cancel.assert_called_once_with()
+ mock_task2.cancel.assert_called_once_with()
+
+ @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
+ @patch.object(pool.TaskPool, '_cancel_group_meta_tasks')
+ def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock,
+ mock_base__cancel_and_remove_all_from_group: MagicMock):
+ group_name, group_reg, msg = 'xyz', MagicMock(), FOO
+ self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg))
+ mock__cancel_group_meta_tasks.assert_called_once_with(group_name)
+ mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg)
+
+ @patch.object(pool.BaseTaskPool, 'cancel_group')
+ async def test_cancel_group(self, mock_base_cancel_group: AsyncMock):
+ group_name, msg = 'abc', 'xyz'
+ await self.task_pool.cancel_group(group_name, msg=msg)
+ mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg)
+
+ @patch.object(pool.BaseTaskPool, 'cancel_all')
+ async def test_cancel_all(self, mock_base_cancel_all: AsyncMock):
+ msg = 'xyz'
+ await self.task_pool.cancel_all(msg=msg)
+ mock_base_cancel_all.assert_awaited_once_with(msg=msg)
+
+ def test__pop_ended_meta_tasks(self):
+ mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
+ self.task_pool._group_meta_tasks_running[FOO] = {mock_task, mock_done_task1}
+ mock_done_task2, mock_done_task3 = MagicMock(done=lambda: True), MagicMock(done=lambda: True)
+ self.task_pool._group_meta_tasks_running[BAR] = {mock_done_task2, mock_done_task3}
+ expected_output = {mock_done_task1, mock_done_task2, mock_done_task3}
+ output = self.task_pool._pop_ended_meta_tasks()
+ self.assertSetEqual(expected_output, output)
+ self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
+
+ @patch.object(pool.TaskPool, '_pop_ended_meta_tasks')
+ @patch.object(pool.BaseTaskPool, 'flush')
+ async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock):
+ mock_ended_meta_task = AsyncMock()
+ mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
+ mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
+ self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
+ self.assertIsNone(await self.task_pool.flush(return_exceptions=False))
+ mock_base_flush.assert_awaited_once_with(return_exceptions=False)
+ mock__pop_ended_meta_tasks.assert_called_once_with()
+ mock_ended_meta_task.assert_awaited_once_with()
+ mock_cancelled_meta_task.assert_awaited_once_with()
+ self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
+
+ @patch.object(pool.BaseTaskPool, 'gather_and_close')
+ async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock):
+ mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
+ self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
+ mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
+ self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
+ self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
+ mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
+ mock_meta_task1.assert_awaited_once_with()
+ mock_meta_task2.assert_awaited_once_with()
+ mock_cancelled_meta_task.assert_awaited_once_with()
+ self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
+ self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
+
+ @patch.object(pool, 'datetime')
+ def test__generate_group_name(self, mock_datetime: MagicMock):
+ prefix, func = 'x y z', AsyncMock(__name__=BAR)
+ dt = datetime(1776, 7, 4, 0, 0, 1)
+ mock_datetime.now = MagicMock(return_value=dt)
+ expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}'
+ output = pool.TaskPool._generate_group_name(prefix, func)
self.assertEqual(expected_output, output)
- mock_func.assert_called_once_with(*args, **kwargs)
- mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
+
+ @patch.object(pool.TaskPool, '_start_task')
+ async def test__apply_num(self, mock__start_task: AsyncMock):
+ group_name = FOO + BAR
+ mock_awaitable = object()
+ mock_func = MagicMock(return_value=mock_awaitable)
+ args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
+ end_cb, cancel_cb = MagicMock(), MagicMock()
+ self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
+ mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
+ mock__start_task.assert_has_awaits(3 * [
+ call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
+ ])
mock_func.reset_mock()
mock__start_task.reset_mock()
- output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
- self.assertEqual(expected_output, output)
- mock_func.assert_called_once_with(*args)
- mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
+ self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
+ mock_func.assert_has_calls(num * [call(*args)])
+ mock__start_task.assert_has_awaits(num * [
+ call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
+ ])
- @patch.object(pool.TaskPool, '_apply_one')
- async def test_apply(self, mock__apply_one: AsyncMock):
- mock__apply_one.return_value = mock_id = 67890
- mock_func, num = MagicMock(), 3
+ @patch.object(pool, 'create_task')
+ @patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock)
+ @patch.object(pool, 'TaskGroupRegister')
+ @patch.object(pool.TaskPool, '_generate_group_name')
+ @patch.object(pool.BaseTaskPool, '_check_start')
+ async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
+ mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
+ mock__generate_group_name.return_value = generated_name = 'name 123'
+ mock_group_reg = set_up_mock_group_register(mock_reg_cls)
+ mock__apply_num.return_value = mock_apply_coroutine = object()
+ mock_task_future = AsyncMock()
+ mock_create_task.return_value = mock_task_future()
+ mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
- expected_output = num * [mock_id]
- output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
- self.assertEqual(expected_output, output)
- mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
+ self.task_pool._task_groups = {}
- async def test__queue_producer(self):
+ def check_assertions(_group_name, _output):
+ self.assertEqual(_group_name, _output)
+ mock__check_start.assert_called_once_with(function=mock_func)
+ self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
+ mock_group_reg.__aenter__.assert_awaited_once_with()
+ mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
+ mock_create_task.assert_called_once_with(mock_apply_coroutine)
+ mock_group_reg.__aexit__.assert_awaited_once()
+ mock_task_future.assert_awaited_once_with()
+
+ output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
+ check_assertions(group_name, output)
+ mock__generate_group_name.assert_not_called()
+
+ mock__check_start.reset_mock()
+ self.task_pool._task_groups.clear()
+ mock_group_reg.__aenter__.reset_mock()
+ mock__apply_num.reset_mock()
+ mock_create_task.reset_mock()
+ mock_group_reg.__aexit__.reset_mock()
+ mock_task_future = AsyncMock()
+ mock_create_task.return_value = mock_task_future()
+
+ output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
+ check_assertions(generated_name, output)
+ mock__generate_group_name.assert_called_once_with('apply', mock_func)
+
+ @patch.object(pool, 'Queue')
+ async def test__queue_producer(self, mock_queue_cls: MagicMock):
mock_put = AsyncMock()
- mock_q = MagicMock(put=mock_put)
- args = (FOO, BAR, 123)
- assert not self.task_pool._interrupt_flag.is_set()
- self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
- mock_put.assert_has_awaits([call(arg) for arg in args])
+ mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
+ item1, item2, item3 = FOO, 420, 69
+ arg_iter = iter([item1, item2, item3])
+ self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
+ mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
+ with self.assertRaises(StopIteration):
+ next(arg_iter)
+
mock_put.reset_mock()
- self.task_pool._interrupt_flag.set()
- self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
- mock_put.assert_not_awaited()
- @patch.object(pool, 'partial')
- @patch.object(pool, 'star_function')
- @patch.object(pool.TaskPool, '_start_task')
- async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
- mock_partial: MagicMock):
- mock_partial.return_value = queue_callback = 'not really'
- mock_star_function.return_value = awaitable = 'totally an awaitable'
- q, arg = Queue(), 420.69
- q.put_nowait(arg)
- mock_func, stars = MagicMock(), 3
- mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock()
- self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
- self.assertTrue(q.empty())
- mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True,
- end_callback=queue_callback, cancel_callback=cancel_cb)
- mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
- mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
- q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars,
- end_callback=end_cb, cancel_callback=cancel_cb)
- mock__start_task.reset_mock()
- mock_star_function.reset_mock()
- mock_partial.reset_mock()
-
- self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
- self.assertTrue(q.empty())
- mock__start_task.assert_not_awaited()
- mock_star_function.assert_not_called()
- mock_partial.assert_not_called()
+ mock_put.side_effect = [CancelledError, None]
+ arg_iter = iter([item1, item2, item3])
+ mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
+ self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
+ mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
+ mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
+ mock_queue.item_processed.assert_has_calls([call(), call()])
+ self.assertListEqual([item2, item3], list(arg_iter))
@patch.object(pool, 'execute_optional')
- @patch.object(pool.TaskPool, '_queue_consumer')
- async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
- task_id, mock_q = 420, MagicMock()
- mock_func, stars = MagicMock(), 3
- mock_wait = AsyncMock()
- mock_flag = MagicMock(wait=mock_wait)
- end_cb, cancel_cb = MagicMock(), MagicMock()
- self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars,
- end_callback=end_cb, cancel_callback=cancel_cb))
- mock_wait.assert_awaited_once_with()
- mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars,
- end_callback=end_cb, cancel_callback=cancel_cb)
- mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
+ async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
+ semaphore, mock_end_cb = Semaphore(1), MagicMock()
+ wrapped = pool.TaskPool._get_map_end_callback(semaphore, mock_end_cb)
+ task_id = 1234
+ await wrapped(task_id)
+ self.assertEqual(2, semaphore._value)
+ mock_execute_optional.assert_awaited_once_with(mock_end_cb, args=(task_id,))
+
+ @patch.object(pool, 'star_function')
+ @patch.object(pool.TaskPool, '_start_task')
+ @patch.object(pool, 'Semaphore')
+ @patch.object(pool.TaskPool, '_get_map_end_callback')
+ async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock,
+ mock__start_task: AsyncMock, mock_star_function: MagicMock):
+ mock__get_map_end_callback.return_value = map_cb = MagicMock()
+ mock_semaphore_cls.return_value = semaphore = Semaphore(3)
+ mock_star_function.return_value = awaitable = 'totally an awaitable'
+ arg1, arg2 = 123456789, 'function argument'
+ mock_q_maxsize = 3
+ mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]),
+ __aexit__=AsyncMock(), maxsize=mock_q_maxsize)
+ group_name, mock_func, stars = 'whatever', MagicMock(), 3
+ end_cb, cancel_cb = MagicMock(), MagicMock()
+ self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
+ self.assertTrue(semaphore.locked())
+ mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
+ mock__start_task.assert_has_awaits(2 * [
+ call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
+ ])
+ mock_star_function.assert_has_calls([
+ call(mock_func, arg1, arg_stars=stars),
+ call(mock_func, arg2, arg_stars=stars)
+ ])
- @patch.object(pool, 'iter')
@patch.object(pool, 'create_task')
- @patch.object(pool, 'join_queue', new_callable=MagicMock)
+ @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
- async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
- mock_create_task: MagicMock, mock_iter: MagicMock):
- args, num_tasks = (FOO, BAR, 1, 2, 3), 2
- mock_join_queue.return_value = mock_join = 'awaitable'
- mock_iter.return_value = args_iter = iter(args)
- mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
- output_q = self.task_pool._set_up_args_queue(args, num_tasks)
- self.assertIsInstance(output_q, Queue)
- self.assertEqual(num_tasks, output_q.qsize())
- for arg in args[:num_tasks]:
- self.assertEqual(arg, output_q.get_nowait())
- self.assertTrue(output_q.empty())
- for arg in args[num_tasks:]:
- self.assertEqual(arg, next(args_iter))
- with self.assertRaises(StopIteration):
- next(args_iter)
- self.assertListEqual([mock_join], self.task_pool._before_gathering)
- mock_join_queue.assert_called_once_with(output_q)
- mock__queue_producer.assert_called_once_with(output_q, args_iter)
- mock_create_task.assert_called_once_with(mock_producer_coro)
+ @patch.object(pool, 'join_queue', new_callable=MagicMock)
+ @patch.object(pool, 'Queue')
+ @patch.object(pool, 'TaskGroupRegister')
+ @patch.object(pool.BaseTaskPool, '_check_start')
+ async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
+ mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
+ mock_create_task: MagicMock):
+ mock_group_reg = set_up_mock_group_register(mock_reg_cls)
+ mock_queue_cls.return_value = mock_q = MagicMock()
+ mock_join_queue.return_value = fake_join = object()
+ mock__queue_producer.return_value = fake_producer = object()
+ mock__queue_consumer.return_value = fake_consumer = object()
+ fake_task1, fake_task2 = object(), object()
+ mock_create_task.side_effect = [fake_task1, fake_task2]
- self.task_pool._before_gathering.clear()
- mock_join_queue.reset_mock()
- mock__queue_producer.reset_mock()
- mock_create_task.reset_mock()
-
- num_tasks = 6
- mock_iter.return_value = args_iter = iter(args)
- output_q = self.task_pool._set_up_args_queue(args, num_tasks)
- self.assertIsInstance(output_q, Queue)
- self.assertEqual(len(args), output_q.qsize())
- for arg in args:
- self.assertEqual(arg, output_q.get_nowait())
- self.assertTrue(output_q.empty())
- with self.assertRaises(StopIteration):
- next(args_iter)
- self.assertListEqual([mock_join], self.task_pool._before_gathering)
- mock_join_queue.assert_called_once_with(output_q)
- mock__queue_producer.assert_not_called()
- mock_create_task.assert_not_called()
-
- @patch.object(pool, 'Event')
- @patch.object(pool.TaskPool, '_queue_consumer')
- @patch.object(pool.TaskPool, '_set_up_args_queue')
- async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock,
- mock_event_cls: MagicMock):
- qsize = 4
- mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
- mock_flag_set = MagicMock()
- mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
-
- mock_func, stars = MagicMock(), 3
- args_iter, group_size = (FOO, BAR, 1, 2, 3), 2
+ group_name, group_size = 'onetwothree', 0
+ func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
end_cb, cancel_cb = MagicMock(), MagicMock()
- self.task_pool._locked = False
- with self.assertRaises(exceptions.PoolIsLocked):
- await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb)
- mock__set_up_args_queue.assert_not_called()
- mock__queue_consumer.assert_not_awaited()
- mock_flag_set.assert_not_called()
+ with self.assertRaises(ValueError):
+ await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
+ mock__check_start.assert_called_once_with(function=func)
- self.task_pool._locked = True
- self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb))
- mock__set_up_args_queue.assert_called_once_with(args_iter, group_size)
- mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
- end_callback=end_cb, cancel_callback=cancel_cb)])
- mock_flag_set.assert_called_once_with()
+ mock__check_start.reset_mock()
+
+ group_size = 1234
+ self.task_pool._task_groups = {group_name: MagicMock()}
+
+ with self.assertRaises(exceptions.InvalidGroupName):
+ await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
+ mock__check_start.assert_called_once_with(function=func)
+
+ mock__check_start.reset_mock()
+
+ self.task_pool._task_groups.clear()
+ self.task_pool._before_gathering = []
+
+ self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
+ mock__check_start.assert_called_once_with(function=func)
+ mock_reg_cls.assert_called_once_with()
+ self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name])
+ mock_group_reg.__aenter__.assert_awaited_once_with()
+ mock_queue_cls.assert_called_once_with(maxsize=group_size)
+ mock_join_queue.assert_called_once_with(mock_q)
+ self.assertListEqual([fake_join], self.task_pool._before_gathering)
+ mock__queue_producer.assert_called_once()
+ mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
+ mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
+ self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
+ mock_group_reg.__aexit__.assert_awaited_once()
@patch.object(pool.TaskPool, '_map')
- async def test_map(self, mock__map: AsyncMock):
+ @patch.object(pool.TaskPool, '_generate_group_name')
+ async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
+ mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock()
- arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2
+ arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock()
- self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb))
- mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size,
+ output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
+ self.assertEqual(group_name, output)
+ mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb)
+ mock__generate_group_name.assert_not_called()
+
+ mock__map.reset_mock()
+ output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
+ self.assertEqual(generated_name, output)
+ mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
+ end_callback=end_cb, cancel_callback=cancel_cb)
+ mock__generate_group_name.assert_called_once_with('map', mock_func)
@patch.object(pool.TaskPool, '_map')
- async def test_starmap(self, mock__map: AsyncMock):
+ @patch.object(pool.TaskPool, '_generate_group_name')
+ async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
+ mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock()
- args_iter, group_size = ([FOO], [BAR]), 2
+ args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock()
- self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb))
- mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size,
+ output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
+ self.assertEqual(group_name, output)
+ mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb)
+ mock__generate_group_name.assert_not_called()
+
+ mock__map.reset_mock()
+ output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
+ self.assertEqual(generated_name, output)
+ mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
+ end_callback=end_cb, cancel_callback=cancel_cb)
+ mock__generate_group_name.assert_called_once_with('starmap', mock_func)
@patch.object(pool.TaskPool, '_map')
- async def test_doublestarmap(self, mock__map: AsyncMock):
+ @patch.object(pool.TaskPool, '_generate_group_name')
+ async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
+ mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock()
- kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2
+ kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock()
- self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb))
- mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size,
+ output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
+ self.assertEqual(group_name, output)
+ mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb)
+ mock__generate_group_name.assert_not_called()
+
+ mock__map.reset_mock()
+ output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
+ self.assertEqual(generated_name, output)
+ mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
+ end_callback=end_cb, cancel_callback=cancel_cb)
+ mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
class SimpleTaskPoolTestCase(CommonTestCase):
@@ -667,7 +806,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def test_stop(self, mock_cancel: MagicMock):
num = 2
id1, id2, id3 = 5, 6, 7
- self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR}
+ self.task_pool._tasks_running = {id1: FOO, id2: BAR, id3: FOO + BAR}
output = self.task_pool.stop(num)
expected_output = [id3, id2]
self.assertEqual(expected_output, output)
@@ -689,3 +828,10 @@ class SimpleTaskPoolTestCase(CommonTestCase):
self.assertEqual(expected_output, output)
mock_num_running.assert_called_once_with()
mock_stop.assert_called_once_with(num)
+
+
+def set_up_mock_group_register(mock_reg_cls: MagicMock) -> MagicMock:
+ mock_grp_aenter, mock_grp_aexit, mock_grp_add = AsyncMock(), AsyncMock(), MagicMock()
+ mock_reg_cls.return_value = mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit,
+ add=mock_grp_add)
+ return mock_group_reg
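
For reference, the auto-generated group names pinned down by `test__generate_group_name` above combine the method prefix, the coroutine function's name, and a timestamp. A standalone re-implementation of that expectation (illustrative, not the library's own code):

```python
from datetime import datetime

DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'  # value from asyncio_taskpool.constants


def generate_group_name(prefix: str, func) -> str:
    # Mirrors the assertion in `test__generate_group_name`.
    return f'{prefix}_{func.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}'


async def work(): ...

print(generate_group_name('map', work))  # e.g. 'map_work_2022-02-22_14-58-03'
```
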
diff --git a/usage/USAGE.md b/usage/USAGE.md
index 33272b4..741bbb1 100644
--- a/usage/USAGE.md
+++ b/usage/USAGE.md
@@ -28,7 +28,7 @@ async def work(n: int) -> None:
"""
for i in range(n):
await asyncio.sleep(1)
- print("did", i)
+ print("> did", i)
async def main() -> None:
@@ -39,7 +39,7 @@ async def main() -> None:
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2
pool.lock() # required for the last line
- await pool.gather() # awaits all tasks, then flushes the pool
+ await pool.gather_and_close() # awaits all tasks, then flushes the pool
if __name__ == '__main__':
@@ -52,29 +52,29 @@ SimpleTaskPool-0 initialized
Started SimpleTaskPool-0_Task-0
Started SimpleTaskPool-0_Task-1
Started SimpleTaskPool-0_Task-2
-did 0
-did 0
-did 0
+> did 0
+> did 0
+> did 0
Started SimpleTaskPool-0_Task-3
-did 1
-did 1
-did 1
-did 0
+> did 1
+> did 1
+> did 1
+> did 0
+> did 2
+> did 2
SimpleTaskPool-0 is locked!
-Cancelling SimpleTaskPool-0_Task-3 ...
-Cancelled SimpleTaskPool-0_Task-3
-Ended SimpleTaskPool-0_Task-3
Cancelling SimpleTaskPool-0_Task-2 ...
Cancelled SimpleTaskPool-0_Task-2
Ended SimpleTaskPool-0_Task-2
-did 2
-did 2
-did 3
-did 3
+Cancelling SimpleTaskPool-0_Task-3 ...
+Cancelled SimpleTaskPool-0_Task-3
+Ended SimpleTaskPool-0_Task-3
+> did 3
+> did 3
Ended SimpleTaskPool-0_Task-0
Ended SimpleTaskPool-0_Task-1
-did 4
-did 4
+> did 4
+> did 4
```
## Advanced example for `TaskPool`
@@ -101,21 +101,21 @@ async def work(start: int, stop: int, step: int = 1) -> None:
"""Pseudo-worker function counting through a range with a second of sleep in between each iteration."""
for i in range(start, stop, step):
await asyncio.sleep(1)
- print("work with", i)
+ print("> work with", i)
async def other_work(a: int, b: int) -> None:
"""Different pseudo-worker counting through a range with half a second of sleep in between each iteration."""
for i in range(a, b):
await asyncio.sleep(0.5)
- print("other_work with", i)
+ print("> other_work with", i)
async def main() -> None:
# Initialize a new task pool instance and limit its size to 3 tasks.
pool = TaskPool(3)
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments).
- print("Called `apply`")
+ print("> Called `apply`")
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
# Let the tasks work for a bit.
await asyncio.sleep(1.5)
@@ -124,20 +124,18 @@ async def main() -> None:
# Since we set our pool size to 3, and already have two tasks working within the pool,
# only the first one of these will start immediately (and receive ID 2).
# The second one will start (with ID 3), only once there is room in the pool,
- # which -- in this example -- will be the case after ID 2 ends;
- # until then the `starmap` method call **will block**!
+ # which -- in this example -- will be the case after ID 2 ends.
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
- # **only** once there is room in the pool **and** no more than one of these last four tasks is running.
+ # **only** once there is room in the pool **and** no more than one of these new tasks is already running.
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
- print("Calling `starmap`...")
await pool.starmap(other_work, args_list, group_size=2)
- print("`starmap` returned")
+ print("> Called `starmap`")
# Now we lock the pool, so that we can safely await all our tasks.
pool.lock()
# Finally, we block, until all tasks have ended.
- print("Called `gather`")
- await pool.gather()
- print("Done.")
+ print("> Calling `gather_and_close`...")
+ await pool.gather_and_close()
+ print("> Done.")
if __name__ == '__main__':
@@ -152,82 +150,81 @@ Additional comments for the output are provided with `<---` next to the output l
TaskPool-0 initialized
Started TaskPool-0_Task-0
Started TaskPool-0_Task-1
-Called `apply`
-work with 100
-work with 100
-Calling `starmap`... <--- notice that this blocks as expected
-Started TaskPool-0_Task-2
-work with 110
-work with 110
-other_work with 0
-other_work with 1
-work with 120
-work with 120
-other_work with 2
-other_work with 3
-work with 130
-work with 130
-other_work with 4
-other_work with 5
-work with 140
-work with 140
-other_work with 6
-other_work with 7
-work with 150
-work with 150
-other_work with 8
-Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
+> Called `apply`
+> work with 100
+> work with 100
+> Called `starmap` <--- notice that this immediately returns, even before Task-2 is started
+> Calling `gather_and_close`... <--- this blocks `main()` until all tasks have ended
TaskPool-0 is locked!
+Started TaskPool-0_Task-2 <--- at this point the pool is full
+> work with 110
+> work with 110
+> other_work with 0
+> other_work with 1
+> work with 120
+> work with 120
+> other_work with 2
+> other_work with 3
+> work with 130
+> work with 130
+> other_work with 4
+> other_work with 5
+> work with 140
+> work with 140
+> other_work with 6
+> other_work with 7
+> work with 150
+> work with 150
+> other_work with 8
+Ended TaskPool-0_Task-2 <--- this frees up room for one more task from `starmap`
Started TaskPool-0_Task-3
-other_work with 9
-`starmap` returned
-Called `gather` <--- now this will block `main()` until all tasks have ended
-work with 160
-work with 160
-other_work with 10
-other_work with 11
-work with 170
-work with 170
-other_work with 12
-other_work with 13
-work with 180
-work with 180
-other_work with 14
-other_work with 15
+> other_work with 9
+> work with 160
+> work with 160
+> other_work with 10
+> other_work with 11
+> work with 170
+> work with 170
+> other_work with 12
+> other_work with 13
+> work with 180
+> work with 180
+> other_work with 14
+> other_work with 15
Ended TaskPool-0_Task-0
-Ended TaskPool-0_Task-1 <--- even though there is room in the pool now, Task-5 will not start
-Started TaskPool-0_Task-4
-work with 190
-work with 190
-other_work with 16
-other_work with 20
-other_work with 17
-other_work with 21
-other_work with 18
-other_work with 22
-other_work with 19
-Ended TaskPool-0_Task-3 <--- now that only Task-4 is left, Task-5 will start
+Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
+Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start yet
+> work with 190
+> work with 190
+> other_work with 16
+> other_work with 17
+> other_work with 20
+> other_work with 18
+> other_work with 21
+Ended TaskPool-0_Task-3 <--- now that only Task-4 of the group remains, Task-5 starts
Started TaskPool-0_Task-5
-other_work with 23
-other_work with 30
-other_work with 24
-other_work with 31
-other_work with 25
-other_work with 32
-other_work with 26
-other_work with 33
-other_work with 27
-other_work with 34
-other_work with 28
-other_work with 35
+> other_work with 19
+> other_work with 22
+> other_work with 23
+> other_work with 30
+> other_work with 24
+> other_work with 31
+> other_work with 25
+> other_work with 32
+> other_work with 26
+> other_work with 33
+> other_work with 27
+> other_work with 34
+> other_work with 28
+> other_work with 35
+> other_work with 29
+> other_work with 36
Ended TaskPool-0_Task-4
-other_work with 29
-other_work with 36
-other_work with 37
-other_work with 38
-other_work with 39
-Done.
+> other_work with 37
+> other_work with 38
+> other_work with 39
Ended TaskPool-0_Task-5
+> Done.
```
© 2022 Daniil Fajnberg
diff --git a/usage/example_server.py b/usage/example_server.py
index e5649ce..f4a30e0 100644
--- a/usage/example_server.py
+++ b/usage/example_server.py
@@ -74,12 +74,12 @@ async def main() -> None:
control_server_task.cancel()
    # Since our workers should now be stuck waiting for more items to appear in the queue, and no items are left,
    # we can safely cancel their tasks.
- pool.stop_all()
pool.lock()
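+    # Locking the pool first ensures that no new tasks can be started while we cancel the remaining ones.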
+ pool.stop_all()
    # Finally, we allow all tasks to do their cleanup (if they need to do any) upon being cancelled.
    # We block until they all return or raise an exception, but since we are not interested in their exceptions,
    # we just collect them silently along with the return values.
- await pool.gather(return_exceptions=True)
+ await pool.gather_and_close(return_exceptions=True)
await control_server_task