diff --git a/README.md b/README.md index 5cb6ffd..65e7751 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,19 @@ If you need control over a task pool at runtime, you can launch an asynchronous ## Usage Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form: + ```python from asyncio_taskpool import SimpleTaskPool + ... + + async def work(foo, bar): ... + + ... + + async def main(): pool = SimpleTaskPool(work, args=('xyz', 420)) await pool.start(5) @@ -27,11 +35,11 @@ async def main(): pool.stop(3) ... pool.lock() - await pool.gather() + await pool.gather_and_close() ... ``` -Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.) +Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.) For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md). diff --git a/setup.cfg b/setup.cfg index 0e7eb76..fc73473 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.3.5 +version = 0.4.0 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/__main__.py b/src/asyncio_taskpool/__main__.py index b997fdb..9e8c59b 100644 --- a/src/asyncio_taskpool/__main__.py +++ b/src/asyncio_taskpool/__main__.py @@ -54,10 +54,10 @@ def parse_cli() -> Dict[str, Any]: async def main(): kwargs = parse_cli() if kwargs[CONN_TYPE] == UNIX: - client = UnixControlClient(path=kwargs[SOCKET_PATH]) + client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) elif kwargs[CONN_TYPE] == TCP: # TODO: Implement the TCP client class - client = UnixControlClient(path=kwargs[SOCKET_PATH]) + client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) else: print("Invalid connection type", file=sys.stderr) sys.exit(2) diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index c4f0869..9d2b749 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -21,6 +21,9 @@ Constants used by more than one module in the package. PACKAGE_NAME = 'asyncio_taskpool' +DEFAULT_TASK_GROUP = '' +DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S' + CLIENT_EXIT = 'exit' SESSION_MSG_BYTES = 1024 * 100 diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py index dbc28df..6f0c7d5 100644 --- a/src/asyncio_taskpool/exceptions.py +++ b/src/asyncio_taskpool/exceptions.py @@ -23,6 +23,10 @@ class PoolException(Exception): pass +class PoolIsClosed(PoolException): + pass + + class PoolIsLocked(PoolException): pass @@ -43,6 +47,10 @@ class InvalidTaskID(PoolException): pass +class InvalidGroupName(PoolException): + pass + + class PoolStillUnlocked(PoolException): pass diff --git a/src/asyncio_taskpool/group_register.py b/src/asyncio_taskpool/group_register.py new file mode 100644 index 0000000..81b7b81 --- /dev/null +++ b/src/asyncio_taskpool/group_register.py @@ -0,0 +1,75 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +This module contains the definition of the `TaskGroupRegister` class. +""" + + +from asyncio.locks import Lock +from collections.abc import MutableSet +from typing import Iterator, Set + + +class TaskGroupRegister(MutableSet): + """ + This class combines the interface of a regular `set` with that of the `asyncio.Lock`. + + It serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for + preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of + tasks, as well as before starting a task within the group. + """ + + def __init__(self, *task_ids: int) -> None: + self._ids: Set[int] = set(task_ids) + self._lock = Lock() + + def __contains__(self, task_id: int) -> bool: + """Abstract method for the `MutableSet` base class.""" + return task_id in self._ids + + def __iter__(self) -> Iterator[int]: + """Abstract method for the `MutableSet` base class.""" + return iter(self._ids) + + def __len__(self) -> int: + """Abstract method for the `MutableSet` base class.""" + return len(self._ids) + + def add(self, task_id: int) -> None: + """Abstract method for the `MutableSet` base class.""" + self._ids.add(task_id) + + def discard(self, task_id: int) -> None: + """Abstract method for the `MutableSet` base class.""" + self._ids.discard(task_id) + + async def acquire(self) -> bool: + """Wrapper around the lock's `acquire()` method.""" + return await self._lock.acquire() + + def release(self) -> None: + """Wrapper around the lock's `release()` method.""" + self._lock.release() + + async def __aenter__(self) -> None: + """Provides the asynchronous context manager syntax `async with ... :` when using the lock.""" + await self._lock.acquire() + return None + + async def __aexit__(self, exc_type, exc, tb) -> None: + """Provides the asynchronous context manager syntax `async with ... :` when using the lock.""" + self._lock.release() diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index ed3bfb0..9f0abfc 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -32,19 +32,22 @@ For further details about the classes check their respective docstrings. import logging -from asyncio import gather from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.exceptions import CancelledError -from asyncio.locks import Event, Semaphore -from asyncio.queues import Queue, QueueEmpty -from asyncio.tasks import Task, create_task -from functools import partial +from asyncio.locks import Semaphore +from asyncio.queues import QueueEmpty +from asyncio.tasks import Task, create_task, gather +from contextlib import suppress +from datetime import datetime from math import inf -from typing import Any, Awaitable, Dict, Iterable, Iterator, List +from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set from . import exceptions +from .constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT +from .group_register import TaskGroupRegister from .helpers import execute_optional, star_function, join_queue -from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT +from .queue_context import Queue +from .types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB log = logging.getLogger(__name__) @@ -62,19 +65,30 @@ class BaseTaskPool: def __init__(self, pool_size: int = inf, name: str = None) -> None: """Initializes the necessary internal attributes and adds the new pool to the general pools list.""" - self._enough_room: Semaphore = Semaphore() - self.pool_size = pool_size + # Initialize a counter for the total number of tasks started through the pool and one for the total number of + # tasks cancelled through the pool. + self._num_started: int = 0 + self._num_cancellations: int = 0 + + # Initialize flags; immutably set the name. self._locked: bool = False - self._counter: int = 0 - self._running: Dict[int, Task] = {} - self._cancelled: Dict[int, Task] = {} - self._ended: Dict[int, Task] = {} - self._num_cancelled: int = 0 - self._num_ended: int = 0 - self._idx: int = self._add_pool(self) + self._closed: bool = False self._name: str = name + + # The following three dictionaries are the actual containers of the tasks controlled by the pool. + self._tasks_running: Dict[int, Task] = {} + self._tasks_cancelled: Dict[int, Task] = {} + self._tasks_ended: Dict[int, Task] = {} + + # These next three attributes act as synchronisation primitives necessary for managing the pool. self._before_gathering: List[Awaitable] = [] - self._interrupt_flag: Event = Event() + self._enough_room: Semaphore = Semaphore() + self._task_groups: Dict[str, TaskGroupRegister[int]] = {} + + # Finish with method/functions calls that add the pool to the internal list of pools, set its initial size, + # and issue a log message. + self._idx: int = self._add_pool(self) + self.pool_size = pool_size log.debug("%s initialized", str(self)) def __str__(self) -> str: @@ -125,34 +139,36 @@ class BaseTaskPool: def num_running(self) -> int: """ Returns the number of tasks in the pool that are (at that moment) still running. - At the moment a task's `end_callback` is fired, it is no longer considered to be running. + + At the moment a task's `end_callback` or `cancel_callback` is fired, it is no longer considered running. """ - return len(self._running) + return len(self._tasks_running) @property - def num_cancelled(self) -> int: + def num_cancellations(self) -> int: """ Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment). - At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running. + + At the moment a task's `cancel_callback` is fired, this counts as a cancellation, and the task is then + considered cancelled (instead of running) until its `end_callback` is fired. """ - return self._num_cancelled + return self._num_cancellations @property def num_ended(self) -> int: """ Returns the number of tasks started through the pool that have stopped running (up until that moment). - At the moment a task's `end_callback` is fired, it is considered ended. + + At the moment a task's `end_callback` is fired, it is considered ended and no longer running (or cancelled). When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned, does it then actually end. """ - return self._num_ended + return len(self._tasks_ended) @property def num_finished(self) -> int: - """ - Returns the number of tasks in the pool that have actually finished running (without having been cancelled). - """ - return self._num_ended - self._num_cancelled + len(self._cancelled) + """Returns the number of tasks in the pool that have finished running (without having been cancelled).""" + return len(self._tasks_ended) - self._num_cancellations + len(self._tasks_cancelled) @property def is_full(self) -> bool: @@ -162,14 +178,61 @@ class BaseTaskPool: """ return self._enough_room.locked() - # TODO: Consider adding task group names + def get_task_group_ids(self, group_name: str) -> Set[int]: + """ + Returns the set of IDs of all tasks in the specified group. + + Args: + group_name: Must be a name of a task group that exists within the pool. + + Returns: + Set of integers representing the task IDs belonging to the specified group. + + Raises: + `InvalidGroupName` if no task group named `group_name` exists in the pool. + """ + try: + return set(self._task_groups[group_name]) + except KeyError: + raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") + + def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None, + ignore_lock: bool = False) -> None: + """ + Checks necessary conditions for starting a task (group) in the pool. + + Either something that is expected to be a coroutine (i.e. awaitable) or something that is expected to be a + coroutine _function_ will be checked. + + Args: + awaitable: If this is passed, `function` must be `None`. + function: If this is passed, `awaitable` must be `None`. + ignore_lock (optional): If `True`, a locked pool will produce no error here. + + Raises: + `AssertionError` if both or neither of `awaitable` and `function` were passed. + `asyncio_taskpool.exceptions.PoolIsClosed` if the pool is closed. + `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a cor. / `function` not a cor. func. + `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`. + """ + assert (awaitable is None) != (function is None) + if awaitable and not iscoroutine(awaitable): + raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") + if function and not iscoroutinefunction(function): + raise exceptions.NotCoroutine(f"Not a coroutine function: {function}") + if self._closed: + raise exceptions.PoolIsClosed("You must use another pool") + if self._locked and not ignore_lock: + raise exceptions.PoolIsLocked("Cannot start new tasks") + def _task_name(self, task_id: int) -> str: """Returns a standardized name for a task with a specific `task_id`.""" return f'{self}_Task-{task_id}' - async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None: + async def _task_cancellation(self, task_id: int, custom_callback: CancelCB = None) -> None: """ Universal callback to be run upon any task in the pool being cancelled. + Required for keeping track of running/cancelled tasks and proper logging. Args: @@ -180,14 +243,15 @@ class BaseTaskPool: It is run at the end of this function with the `task_id` as its only positional argument. """ log.debug("Cancelling %s ...", self._task_name(task_id)) - self._cancelled[task_id] = self._running.pop(task_id) - self._num_cancelled += 1 + self._tasks_cancelled[task_id] = self._tasks_running.pop(task_id) + self._num_cancellations += 1 log.debug("Cancelled %s", self._task_name(task_id)) await execute_optional(custom_callback, args=(task_id,)) - async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None: + async def _task_ending(self, task_id: int, custom_callback: EndCB = None) -> None: """ Universal callback to be run upon any task in the pool ending its work. + Required for keeping track of running/cancelled/ended tasks and proper logging. Also releases room in the task pool for potentially waiting tasks. @@ -199,19 +263,20 @@ class BaseTaskPool: It is run at the end of this function with the `task_id` as its only positional argument. """ try: - self._ended[task_id] = self._running.pop(task_id) + self._tasks_ended[task_id] = self._tasks_running.pop(task_id) except KeyError: - self._ended[task_id] = self._cancelled.pop(task_id) - self._num_ended += 1 + self._tasks_ended[task_id] = self._tasks_cancelled.pop(task_id) self._enough_room.release() log.info("Ended %s", self._task_name(task_id)) await execute_optional(custom_callback, args=(task_id,)) - async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None, - cancel_callback: CancelCallbackT = None) -> Any: + async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCB = None, + cancel_callback: CancelCB = None) -> Any: """ - Universal wrapper around every task to be run in the pool. - Returns/raises whatever the wrapped coroutine does. + Universal wrapper around every task run in the pool that returns/raises whatever the wrapped coroutine does. + + Responsible for catching cancellation and awaiting the `_task_cancellation` callback, as well as for awaiting + the `_task_ending` callback, after the coroutine returns or raises an exception. Args: awaitable: @@ -233,16 +298,19 @@ class BaseTaskPool: finally: await self._task_ending(task_id, custom_callback=end_callback) - async def _start_task(self, awaitable: Awaitable, ignore_lock: bool = False, end_callback: EndCallbackT = None, - cancel_callback: CancelCallbackT = None) -> int: + async def _start_task(self, awaitable: Awaitable, group_name: str = DEFAULT_TASK_GROUP, ignore_lock: bool = False, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> int: """ Starts a coroutine as a new task in the pool. - This method blocks, **only if** the pool is full. - Returns/raises whatever the wrapped coroutine does. + + This method can block for a significant amount of time, **only if** the pool is full. + Otherwise it merely needs to acquire the `TaskGroupRegister` lock, which should never be held for a long time. Args: awaitable: The actual coroutine to be run within the task pool. + group_name (optional): + Name of the task group to add the new task to; defaults to the `DEFAULT_TASK_GROUP` constant. ignore_lock (optional): If `True`, even if the pool is locked, the task will still be started. end_callback (optional): @@ -252,25 +320,20 @@ class BaseTaskPool: A callback to execute after cancellation of the task. It is run with the task's ID as its only positional argument. - Raises: - `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine. - `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`. + Returns: + The ID of the newly started task. """ - if not iscoroutine(awaitable): - raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") - if self._locked and not ignore_lock: - raise exceptions.PoolIsLocked("Cannot start new tasks") + self._check_start(awaitable=awaitable, ignore_lock=ignore_lock) await self._enough_room.acquire() - task_id = self._counter - self._counter += 1 - try: - self._running[task_id] = create_task( - self._task_wrapper(awaitable, task_id, end_callback, cancel_callback), + group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister()) + async with group_reg: + task_id = self._num_started + self._num_started += 1 + group_reg.add(task_id) + self._tasks_running[task_id] = create_task( + coro=self._task_wrapper(awaitable, task_id, end_callback, cancel_callback), name=self._task_name(task_id) ) - except Exception as e: - self._enough_room.release() - raise e return task_id def _get_running_task(self, task_id: int) -> Task: @@ -286,11 +349,11 @@ class BaseTaskPool: `asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool. """ try: - return self._running[task_id] + return self._tasks_running[task_id] except KeyError: - if self._cancelled.get(task_id): - raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled") - if self._ended.get(task_id): + if self._tasks_cancelled.get(task_id): + raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has been cancelled") + if self._tasks_ended.get(task_id): raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") @@ -303,91 +366,111 @@ class BaseTaskPool: - `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled. - `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently). - `InvalidTaskID` if any of the `task_ids` is not known to the pool. - Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten. + Note that once a pool has been flushed (see below), IDs of tasks that have ended previously will be forgotten. Args: - task_ids: - Arbitrary number of integers. Each must be an ID of a task still running within the pool. - msg (optional): - Passed to the `Task.cancel()` method of every task specified by the `task_ids`. + task_ids: Arbitrary number of integers. Each must be an ID of a task still running within the pool. + msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`. """ tasks = [self._get_running_task(task_id) for task_id in task_ids] for task in tasks: task.cancel(msg=msg) - def cancel_all(self, msg: str = None) -> None: + def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None: + """ + Removes all tasks from the specified group and cancels them, if they are still running. + + Args: + group_name: The name of the group of tasks that shall be cancelled. + group_reg: The task group register object containing the task IDs. + msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`. + """ + while group_reg: + try: + self._tasks_running[group_reg.pop()].cancel(msg=msg) + except KeyError: + continue + log.debug("%s cancelled tasks from group %s", str(self), group_name) + + async def cancel_group(self, group_name: str, msg: str = None) -> None: + """ + Cancels an entire group of tasks. The task group is subsequently forgotten by the pool. + + Args: + group_name: The name of the group of tasks that shall be cancelled. + msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`. + + Raises: + `InvalidGroupName` if no task group named `group_name` exists in the pool. + """ + log.debug("%s cancelling tasks in group %s", str(self), group_name) + try: + group_reg = self._task_groups.pop(group_name) + except KeyError: + raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") + async with group_reg: + self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) + log.debug("%s forgot task group %s", str(self), group_name) + + async def cancel_all(self, msg: str = None) -> None: """ Cancels all tasks still running within the pool. - Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks. - This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller - than the number of arguments from `args_iter`. - In this case, those already running will be cancelled, while the following will **never even start**. - Args: - msg (optional): - Passed to the `Task.cancel()` method of every task specified by the `task_ids`. + msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`. """ log.warning("%s cancelling all tasks!", str(self)) - self._interrupt_flag.set() - for task in self._running.values(): - task.cancel(msg=msg) + while self._task_groups: + group_name, group_reg = self._task_groups.popitem() + async with group_reg: + self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) async def flush(self, return_exceptions: bool = False): """ - Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks. + Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks. + + This method exists mainly to free up memory of unneeded `Task` objects. + This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks registered for the tasks block. Args: return_exceptions (optional): Passed directly into `gather`. """ - results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions) - self._ended.clear() - self._cancelled.clear() - if self._interrupt_flag.is_set(): - self._interrupt_flag.clear() - return results + await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), return_exceptions=return_exceptions) + self._tasks_ended.clear() + self._tasks_cancelled.clear() - async def gather(self, return_exceptions: bool = False): + async def gather_and_close(self, return_exceptions: bool = False): """ - Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks. + Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool. The `lock()` method must have been called prior to this. - Note that there may be an unknown number of coroutine functions already "queued" to be run as tasks. - This can happen, if for example the `TaskPool.map` method was called with `group_size` set to a number smaller - than the number of arguments from `args_iter`. - In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially - blocking this method. Otherwise it will wait until they all have started. - - This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the - callbacks registered for a task blocks. + This method may block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of the + callbacks registered for a task blocks for whatever reason. Args: return_exceptions (optional): Passed directly into `gather`. Raises: - `asyncio_taskpool.exceptions.PoolStillUnlocked` if the pool has not been locked yet. + `PoolStillUnlocked` if the pool has not been locked yet. """ if not self._locked: raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered") await gather(*self._before_gathering) - results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), - return_exceptions=return_exceptions) - self._ended.clear() - self._cancelled.clear() - self._running.clear() + await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(), + return_exceptions=return_exceptions) + self._tasks_ended.clear() + self._tasks_cancelled.clear() + self._tasks_running.clear() self._before_gathering.clear() - if self._interrupt_flag.is_set(): - self._interrupt_flag.clear() - return results + self._closed = True class TaskPool(BaseTaskPool): """ - General task pool class. - Attempts to somewhat emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib. + General task pool class. Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib. A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function. Tasks in the pool can all belong to the same coroutine function, @@ -399,41 +482,187 @@ class TaskPool(BaseTaskPool): Adding tasks blocks **only if** the pool is full at that moment. """ - async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int: + _QUEUE_END_SENTINEL = object() + + def __init__(self, pool_size: int = inf, name: str = None) -> None: + super().__init__(pool_size=pool_size, name=name) + # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of + # meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks. + self._group_meta_tasks_running: Dict[str, Set[Task]] = {} + self._meta_tasks_cancelled: Set[Task] = set() + + def _cancel_group_meta_tasks(self, group_name: str) -> None: + """Cancels and forgets all meta tasks associated with the task group named `group_name`.""" + try: + meta_tasks = self._group_meta_tasks_running.pop(group_name) + except KeyError: + return + for meta_task in meta_tasks: + meta_task.cancel() + self._meta_tasks_cancelled.update(meta_tasks) + log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name) + + def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None: + self._cancel_group_meta_tasks(group_name) + super()._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) + + async def cancel_group(self, group_name: str, msg: str = None) -> None: + """ + Cancels an entire group of tasks. The task group is subsequently forgotten by the pool. + + If any methods such as `map()` launched meta tasks belonging to that group, these meta tasks are cancelled + before the actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will + **never even start**. In the case of `map()` this would mean that the `arg_iter` may be abandoned before it + was fully consumed (if that is even possible). + + Args: + group_name: The name of the group of tasks (and meta tasks) that shall be cancelled. + msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`. + + Raises: + `InvalidGroupName` if no task group named `group_name` exists in the pool. + """ + await super().cancel_group(group_name=group_name, msg=msg) + + async def cancel_all(self, msg: str = None) -> None: + """ + Cancels all tasks still running within the pool. (This includes all meta tasks.) + + If any methods such as `map()` launched meta tasks, these meta tasks are cancelled before the actual tasks are + cancelled. This means that any tasks "queued" to be started by a meta task will **never even start**. In the + case of `map()` this would mean that the `arg_iter` may be abandoned before it was fully consumed (if that is + even possible). + + Args: + msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`. + """ + await super().cancel_all(msg=msg) + + def _pop_ended_meta_tasks(self) -> Set[Task]: + """ + Goes through all not-cancelled meta tasks, checks if they are done already, and returns those that are. + + The internal `_group_meta_tasks_running` dictionary is updated accordingly, i.e. after this method returns, the + values (sets) only contain meta tasks that were not done yet. In addition, names of groups that no longer + have any running meta tasks associated with them are removed from the keys. + """ + obsolete_keys, ended_meta_tasks = [], set() + for group_name in self._group_meta_tasks_running.keys(): + still_running = set() + while self._group_meta_tasks_running[group_name]: + meta_task = self._group_meta_tasks_running[group_name].pop() + if meta_task.done(): + ended_meta_tasks.add(meta_task) + else: + still_running.add(meta_task) + if still_running: + self._group_meta_tasks_running[group_name] = still_running + else: + obsolete_keys.append(group_name) + # If a group no longer has running meta tasks associated with, we can remove its name from the dictionary. + for group_name in obsolete_keys: + del self._group_meta_tasks_running[group_name] + return ended_meta_tasks + + async def flush(self, return_exceptions: bool = False): + """ + Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks. + + This method exists mainly to free up memory of unneeded `Task` objects. It also gets rid of unneeded meta tasks. + + This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the + callbacks registered for the tasks block. + + Args: + return_exceptions (optional): Passed directly into `gather`. + """ + await super().flush(return_exceptions=return_exceptions) + with suppress(CancelledError): + await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(), + return_exceptions=return_exceptions) + self._meta_tasks_cancelled.clear() + + async def gather_and_close(self, return_exceptions: bool = False): + """ + Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool. + + The `lock()` method must have been called prior to this. + + Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta + tasks launched my methods such as `map()`, which ends by itself, only once the `arg_iter` is fully consumed, + which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid + this, make sure to call `cancel_all()` prior to this. + + + This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of + the callbacks registered for a task blocks for whatever reason. + + Args: + return_exceptions (optional): Passed directly into `gather`. + + Raises: + `PoolStillUnlocked` if the pool has not been locked yet. + """ + await super().gather_and_close(return_exceptions=return_exceptions) + not_cancelled_meta_tasks = set() + while self._group_meta_tasks_running: + _, meta_tasks = self._group_meta_tasks_running.popitem() + not_cancelled_meta_tasks.update(meta_tasks) + with suppress(CancelledError): + await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions) + self._meta_tasks_cancelled.clear() + + @staticmethod + def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str: + """ + Creates a task group identifier that includes the current datetime. + + Args: + prefix: The start of the name; will be followed by an underscore. + coroutine_function: The function representing the task group. + + Returns: + The constructed 'prefix_function_datetime' string to name a task group. + """ + return f'{prefix}_{coroutine_function.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}' + + async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, + num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: """ Creates a coroutine with the supplied arguments and runs it as a new task in the pool. - This method blocks, **only if** the pool is full. + This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks. Args: + group_name: + Name of the task group to add the new task to. func: The coroutine function to be run as a task within the task pool. args (optional): The positional arguments to pass into the function call. kwargs (optional): The keyword-arguments to pass into the function call. + num (optional): + The number of tasks to spawn with the specified parameters. end_callback (optional): A callback to execute after the task has ended. It is run with the task's ID as its only positional argument. cancel_callback (optional): A callback to execute after cancellation of the task. It is run with the task's ID as its only positional argument. - - Returns: - The newly spawned task's ID within the pool. """ if kwargs is None: kwargs = {} - return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback) + await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback, + cancel_callback=cancel_callback) for _ in range(num))) async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]: + group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ Creates an arbitrary number of coroutines with the supplied arguments and runs them as new tasks in the pool. - Each coroutine looks like `func(*args, **kwargs)`. - This method blocks, **only if** there is not enough room in the pool for the desired number of new tasks. + Each coroutine looks like `func(*args, **kwargs)`. All the new tasks are added to the same task group. + This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks. Args: func: @@ -444,6 +673,8 @@ class TaskPool(BaseTaskPool): The keyword-arguments to pass into each function call. num (optional): The number of tasks to spawn with the specified parameters. + group_name (optional): + Name of the task group to add the new tasks to. end_callback (optional): A callback to execute after a task has ended. It is run with the task's ID as its only positional argument. @@ -452,176 +683,129 @@ class TaskPool(BaseTaskPool): It is run with the task's ID as its only positional argument. Returns: - The newly spawned tasks' IDs within the pool as a list of integers. + The name of the task group that the newly spawned tasks have been added to. Raises: + `PoolIsClosed` if the pool is closed. `NotCoroutine` if `func` is not a coroutine function. - `PoolIsLocked` if the pool has been locked already. + `PoolIsLocked` if the pool has been locked. """ - ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))) - # TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions - assert isinstance(ids, list) - return ids + self._check_start(function=func) + if group_name is None: + group_name = self._generate_group_name('apply', func) + group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister()) + async with group_reg: + task = create_task(self._apply_num(group_name, func, args, kwargs, num, end_callback, cancel_callback)) + await task + return group_name - async def _queue_producer(self, q: Queue, args_iter: Iterator[Any]) -> None: + @classmethod + async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None: """ Keeps the arguments queue from `_map()` full as long as the iterator has elements. - If the `_interrupt_flag` gets set, the loop ends prematurely. + + Intended to be run as a meta task of a specific group. Args: - q: - The queue of function arguments to consume for starting the next task. - args_iter: - The iterator of function arguments to put into the queue. - """ - for arg in args_iter: - if self._interrupt_flag.is_set(): - break - await q.put(arg) # This blocks as long as the queue is full. - - async def _queue_consumer(self, q: Queue, first_batch_started: Event, func: CoroutineFunc, arg_stars: int = 0, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: - """ - Wrapper around the `_start_task()` taking the next element from the arguments queue set up in `_map()`. - Partially constructs the `_queue_callback` function with the same arguments. - - Args: - q: - The queue of function arguments to consume for starting the next task. - first_batch_started: - The event flag to wait for, before launching the next consumer. - It can only set by the `_map()` method, which happens after the first batch of task has been started. - func: - The coroutine function to use for spawning the tasks within the task pool. - arg_stars (optional): - Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2. - end_callback (optional): - The actual callback specified to execute after the task (and the next one) has ended. - It is run with the task's ID as its only positional argument. - cancel_callback (optional): - The callback that was specified to execute after cancellation of the task (and the next one). - It is run with the task's ID as its only positional argument. + arg_queue: The queue of function arguments to consume for starting a new task. + arg_iter: The iterator of function arguments to put into the queue. + group_name: Name of the task group associated with this producer. """ try: - arg = q.get_nowait() - except QueueEmpty: - return - try: - await self._start_task( - star_function(func, arg, arg_stars=arg_stars), - ignore_lock=True, - end_callback=partial(TaskPool._queue_callback, self, q=q, first_batch_started=first_batch_started, - func=func, arg_stars=arg_stars, end_callback=end_callback, - cancel_callback=cancel_callback), - cancel_callback=cancel_callback - ) + for arg in arg_iter: + await arg_queue.put(arg) # This blocks as long as the queue is full. + except CancelledError: + # This means that no more tasks are supposed to be created from this `_map()` call; + # thus, we can immediately drain the entire queue and forget about the rest of the arguments. + log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name) + while True: + try: + arg_queue.get_nowait() + arg_queue.item_processed() + except QueueEmpty: + return finally: - q.task_done() + await arg_queue.put(cls._QUEUE_END_SENTINEL) - async def _queue_callback(self, task_id: int, q: Queue, first_batch_started: Event, func: CoroutineFunc, - arg_stars: int = 0, end_callback: EndCallbackT = None, - cancel_callback: CancelCallbackT = None) -> None: + @staticmethod + def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB: + """Returns a wrapped `end_callback` for each `_queue_consumer()` task that will release the `map_semaphore`.""" + async def release_callback(task_id: int) -> None: + map_semaphore.release() + await execute_optional(actual_end_callback, args=(task_id,)) + return release_callback + + async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: """ - Wrapper around an end callback function passed into the `_map()` method. - Triggers the next `_queue_consumer` with the same arguments. + Consumes arguments from the queue from `_map()` and keeps a limited number of tasks working on them. + + The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before + a new task can be started, and which must be released when one of these tasks ends. + + Intended to be run as a meta task of a specific group. Args: - task_id: - The ID of the ending task. - q: - The queue of function arguments to consume for starting the next task. - first_batch_started: - The event flag to wait for, before launching the next consumer. - It can only set by the `_map()` method, which happens after the first batch of task has been started. + arg_queue: + The queue of function arguments to consume for starting a new task. + group_name: + Name of the associated task group; passed into the `_start_task()` method. func: - The coroutine function to use for spawning the tasks within the task pool. + The coroutine function to use for spawning the new tasks within the task pool. arg_stars (optional): - Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2. + Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2. end_callback (optional): The actual callback specified to execute after the task (and the next one) has ended. - It is run with the `task_id` as its only positional argument. + It is run with the task's ID as its only positional argument. cancel_callback (optional): The callback that was specified to execute after cancellation of the task (and the next one). - It is run with the `task_id` as its only positional argument. + It is run with the task's ID as its only positional argument. """ - await first_batch_started.wait() - await self._queue_consumer(q, first_batch_started, func, arg_stars, - end_callback=end_callback, cancel_callback=cancel_callback) - await execute_optional(end_callback, args=(task_id,)) + map_semaphore = Semaphore(arg_queue.maxsize) # value determined by `group_size` in `_map()` + release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback) + while True: + # The following line blocks **only if** the number of running tasks spawned by this method has reached the + # specified maximum as determined in the `_map()` method. + await map_semaphore.acquire() + # We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called. + async with arg_queue as next_arg: + if next_arg is self._QUEUE_END_SENTINEL: + # The `_queue_producer()` either reached the last argument or was cancelled. + return + await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, + ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) - def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue: - """ - Helper function for `_map()`. - Takes the iterable of function arguments `args_iter` and adds up to `group_size` to a new `asyncio.Queue`. - The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned. - - If the iterable contains less than `group_size` elements, nothing else happens; otherwise the `_queue_producer` - is started as a separate task with the arguments queue and and iterator of the remaining arguments. - - Args: - args_iter: - The iterable of function arguments passed into `_map()` to use for creating the new tasks. - num_tasks: - The maximum number of the new tasks to run concurrently that was passed into `_map()`. - - Returns: - The newly created and filled arguments queue for spawning new tasks. - """ - # Setting the `maxsize` of the queue to `group_size` will ensure that no more than `group_size` tasks will run - # concurrently because the size of the queue is what will determine the number of immediately started tasks in - # the `_map()` method and each of those will only ever start (at most) one other task upon ending. - args_queue = Queue(maxsize=num_tasks) - self._before_gathering.append(join_queue(args_queue)) - args_iter = iter(args_iter) - try: - # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of - # tasks, which will be at most `group_size` (meaning the queue will be full). - for i in range(num_tasks): - args_queue.put_nowait(next(args_iter)) - except StopIteration: - # If we get here, this means that the number of elements in the arguments iterator was less than the - # specified `group_size`. Still, the number of tasks to start immediately will be the size of the queue. - # The `_queue_producer` won't be necessary, since we already put all the elements in the queue. - pass - else: - # There may be more elements in the arguments iterator, so we need the `_queue_producer`. - # It will have exclusive access to the `args_iter` from now on. - # Since the queue is full already, it will wait until one of the tasks in the first batch ends, - # before putting the next item in it. - create_task(self._queue_producer(args_queue, args_iter)) - return args_queue - - async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, group_size: int = 1, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: """ Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool. - Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `args_iter`. + Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`. + All the new tasks are added to the same task group. The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at - any given moment in time. Assuming the number of elements produced by `args_iter` is greater than `group_size`, - this method will block **only** until the first `group_size` tasks have been **started**, before returning. - (If the number of elements from `args_iter` is smaller than `group_size`, this method will return as soon as - all of them have been started.) + any given moment in time. As soon as one task from this group ends, it triggers the start of a new task + (assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size + of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running + concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course). - As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in - the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a - limit, this ensures that the number of tasks running concurrently as a result of this method call is always - equal to `group_size` (except for when `args_iter` is exhausted of course). - - Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. - - This method sets up an internal arguments queue which is continuously filled while consuming the `args_iter`. + This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`. + Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the + aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does + not mean that any task was started or that any number of tasks will start soon, as this is solely determined by + the `pool_size` and the `group_size`. Args: + group_name: + Name of the task group to add the new tasks to. It must be a name that doesn't exist yet. + group_size: + The maximum number new tasks spawned by this method to run concurrently. func: The coroutine function to use for spawning the new tasks within the task pool. - args_iter: + arg_iter: The iterable of arguments; each element is to be passed into a `func` call when spawning a new task. - arg_stars (optional): + arg_stars: Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2. - group_size (optional): - The maximum number new tasks spawned by this method to run concurrently. Defaults to 1. end_callback (optional): A callback to execute after a task has ended. It is run with the task's ID as its only positional argument. @@ -630,44 +814,46 @@ class TaskPool(BaseTaskPool): It is run with the task's ID as its only positional argument. Raises: - `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked. + `ValueError` if `group_size` is less than 1. + `asyncio_taskpool.exceptions.InvalidGroupName` if a group named `group_name` exists in the pool. """ - if not self._locked: - raise exceptions.PoolIsLocked("Cannot start new tasks") - args_queue = self._set_up_args_queue(args_iter, group_size) - # We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the - # `_queue_callback` triggered by one or more of them. - # This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more - # than one element, and the pool remains full until after the first task of the first batch ends. Then the - # callback might trigger the next `_queue_consumer` before this method can, which will keep it blocked. - first_batch_started = Event() - for _ in range(args_queue.qsize()): - # This is where blocking can occur, if the pool is full. - await self._queue_consumer(args_queue, first_batch_started, func, - arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback) - # Now the callbacks can immediately trigger more tasks. - first_batch_started.set() + self._check_start(function=func) + if group_size < 1: + raise ValueError(f"Group size must be a positive integer.") + if group_name in self._task_groups.keys(): + raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") + self._task_groups[group_name] = group_reg = TaskGroupRegister() + async with group_reg: + # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the + # `_queue_producer()`; that way an argument + arg_queue = Queue(maxsize=group_size) + self._before_gathering.append(join_queue(arg_queue)) + meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) + # Start the producer and consumer meta tasks. + meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name))) + meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars, + end_callback, cancel_callback))) - async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method. Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool. Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`. + All the new tasks are added to the same task group. The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at - any given moment in time. Assuming the number of elements produced by `arg_iter` is greater than `group_size`, - this method will block **only** until the first `group_size` tasks have been **started**, before returning. - (If the number of elements from `arg_iter` is smaller than `group_size`, this method will return as soon as - all of them have been started.) + any given moment in time. As soon as one task from this group ends, it triggers the start of a new task + (assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size + of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running + concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course). - As soon as one task from this first batch ends, it triggers the start of a new task (assuming there is room in - the pool), which consumes the next element from the arguments iterable. If the size of the pool never imposes a - limit, this ensures that the number of tasks running concurrently as a result of this method call is always - equal to `group_size` (except for when `arg_iter` is exhausted of course). - - Thus, this method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. + This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`. + Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the + aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does + not mean that any task was started or that any number of tasks will start soon, as this is solely determined by + the `pool_size` and the `group_size`. Args: func: @@ -676,6 +862,8 @@ class TaskPool(BaseTaskPool): The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task. group_size (optional): The maximum number new tasks spawned by this method to run concurrently. Defaults to 1. + group_name (optional): + Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet. end_callback (optional): A callback to execute after a task has ended. It is run with the task's ID as its only positional argument. @@ -683,32 +871,48 @@ class TaskPool(BaseTaskPool): A callback to execute after cancellation of a task. It is run with the task's ID as its only positional argument. + Returns: + The name of the task group that the newly spawned tasks will be added to. + Raises: - `PoolIsLocked` if the pool has been locked. + `PoolIsClosed` if the pool is closed. `NotCoroutine` if `func` is not a coroutine function. + `PoolIsLocked` if the pool has been locked. + `ValueError` if `group_size` is less than 1. + `InvalidGroupName` if a group named `group_name` exists in the pool. """ - await self._map(func, arg_iter, arg_stars=0, group_size=group_size, + if group_name is None: + group_name = self._generate_group_name('map', func) + await self._map(group_name, group_size, func, arg_iter, 0, end_callback=end_callback, cancel_callback=cancel_callback) + return group_name async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as positional arguments to the function. Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`. """ - await self._map(func, args_iter, arg_stars=1, group_size=group_size, + if group_name is None: + group_name = self._generate_group_name('starmap', func) + await self._map(group_name, group_size, func, args_iter, 1, end_callback=end_callback, cancel_callback=cancel_callback) + return group_name async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + group_name: str = None, end_callback: EndCB = None, + cancel_callback: CancelCB = None) -> str: """ Like `map()` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as keyword-arguments to the function. Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`. """ - await self._map(func, kwargs_iter, arg_stars=2, group_size=group_size, + if group_name is None: + group_name = self._generate_group_name('doublestarmap', func) + await self._map(group_name, group_size, func, kwargs_iter, 2, end_callback=end_callback, cancel_callback=cancel_callback) + return group_name class SimpleTaskPool(BaseTaskPool): @@ -729,7 +933,7 @@ class SimpleTaskPool(BaseTaskPool): """ def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, + end_callback: EndCB = None, cancel_callback: CancelCB = None, pool_size: int = inf, name: str = None) -> None: """ @@ -756,8 +960,8 @@ class SimpleTaskPool(BaseTaskPool): self._func: CoroutineFunc = func self._args: ArgsT = args self._kwargs: KwArgsT = kwargs if kwargs is not None else {} - self._end_callback: EndCallbackT = end_callback - self._cancel_callback: CancelCallbackT = cancel_callback + self._end_callback: EndCB = end_callback + self._cancel_callback: CancelCB = cancel_callback super().__init__(pool_size=pool_size, name=name) @property @@ -784,7 +988,7 @@ class SimpleTaskPool(BaseTaskPool): If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled. """ ids = [] - for i, task_id in enumerate(reversed(self._running)): + for i, task_id in enumerate(reversed(self._tasks_running)): if i >= num: break # We got the desired number of task IDs, there may well be more tasks left to keep running ids.append(task_id) diff --git a/src/asyncio_taskpool/queue_context.py b/src/asyncio_taskpool/queue_context.py new file mode 100644 index 0000000..29959bf --- /dev/null +++ b/src/asyncio_taskpool/queue_context.py @@ -0,0 +1,58 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +This module contains the definition of an `asyncio.Queue` subclass. +""" + + +from asyncio.queues import Queue as _Queue +from typing import Any + + +class Queue(_Queue): + """This just adds a little syntactic sugar to the `asyncio.Queue`.""" + + def item_processed(self) -> None: + """ + Does exactly the same as `task_done()`. + + This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing, + invites confusion, and immensely reduces readability (in the context of this library). And readability counts. + """ + self.task_done() + + async def __aenter__(self) -> Any: + """ + Implements an asynchronous context manager for the queue. + + Upon entering `get()` is awaited and subsequently whatever came out of the queue is returned. + It allows writing code this way: + >>> queue = Queue() + >>> ... + >>> async with queue as item: + >>> ... + """ + return await self.get() + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """ + Implements an asynchronous context manager for the queue. + + Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want, + but in some situations it makes the codes much cleaner. + """ + self.item_processed() diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index f6177e7..b8a5194 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -32,8 +32,8 @@ KwArgsT = Mapping[str, Any] AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] CoroutineFunc = Callable[[...], Awaitable[Any]] -EndCallbackT = Callable -CancelCallbackT = Callable +EndCB = Callable +CancelCB = Callable ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] diff --git a/tests/test_pool.py b/tests/test_pool.py index bbbc64f..aef5fc0 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -18,19 +18,20 @@ __doc__ = """ Unittests for the `asyncio_taskpool.pool` module. """ - -import asyncio from asyncio.exceptions import CancelledError -from asyncio.queues import Queue +from asyncio.locks import Semaphore +from asyncio.queues import QueueEmpty +from datetime import datetime from unittest import IsolatedAsyncioTestCase from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from typing import Type from asyncio_taskpool import pool, exceptions +from asyncio_taskpool.constants import DATETIME_FORMAT -EMPTY_LIST, EMPTY_DICT = [], {} -FOO, BAR = 'foo', 'bar' +EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set() +FOO, BAR, BAZ = 'foo', 'bar', 'baz' class TestException(Exception): @@ -45,19 +46,12 @@ class CommonTestCase(IsolatedAsyncioTestCase): task_pool: pool.BaseTaskPool log_lvl: int - @classmethod - def setUpClass(cls) -> None: - cls.log_lvl = pool.log.level - pool.log.setLevel(999) - - @classmethod - def tearDownClass(cls) -> None: - pool.log.setLevel(cls.log_lvl) - def get_task_pool_init_params(self) -> dict: return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME} def setUp(self) -> None: + self.log_lvl = pool.log.level + pool.log.setLevel(999) self._pools = self.TEST_CLASS._pools # These three methods are called during initialization, so we mock them by default during setup: self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool') @@ -76,6 +70,7 @@ class CommonTestCase(IsolatedAsyncioTestCase): self._add_pool_patcher.stop() self.pool_size_patcher.stop() self.dunder_str_patcher.stop() + pool.log.setLevel(self.log_lvl) class BaseTaskPoolTestCase(CommonTestCase): @@ -88,19 +83,23 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools) def test_init(self): - self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) + self.assertEqual(0, self.task_pool._num_started) + self.assertEqual(0, self.task_pool._num_cancellations) + self.assertFalse(self.task_pool._locked) - self.assertEqual(0, self.task_pool._counter) - self.assertDictEqual(EMPTY_DICT, self.task_pool._running) - self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) - self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) - self.assertEqual(0, self.task_pool._num_cancelled) - self.assertEqual(0, self.task_pool._num_ended) - self.assertEqual(self.mock_idx, self.task_pool._idx) + self.assertFalse(self.task_pool._closed) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) + + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) + self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) - self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event) - self.assertFalse(self.task_pool._interrupt_flag.is_set()) + self.assertIsInstance(self.task_pool._enough_room, Semaphore) + self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) + + self.assertEqual(self.mock_idx, self.task_pool._idx) + self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE) self.mock___str__.assert_called_once_with() @@ -143,26 +142,56 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertFalse(self.task_pool._locked) def test_num_running(self): - self.task_pool._running = ['foo', 'bar', 'baz'] + self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ} self.assertEqual(3, self.task_pool.num_running) - def test_num_cancelled(self): - self.task_pool._num_cancelled = 3 - self.assertEqual(3, self.task_pool.num_cancelled) + def test_num_cancellations(self): + self.task_pool._num_cancellations = 3 + self.assertEqual(3, self.task_pool.num_cancellations) def test_num_ended(self): - self.task_pool._num_ended = 3 + self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ} self.assertEqual(3, self.task_pool.num_ended) def test_num_finished(self): - self.task_pool._num_cancelled = cancelled = 69 - self.task_pool._num_ended = ended = 420 - self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'} - self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished) + self.task_pool._num_cancellations = num_cancellations = 69 + num_ended = 420 + self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)} + self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ} + self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished) def test_is_full(self): self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) + def test_get_task_group_ids(self): + group_name, ids = 'abcdef', [1, 2, 3] + self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids)) + self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name)) + with self.assertRaises(exceptions.InvalidGroupName): + self.task_pool.get_task_group_ids('something else') + + async def test__check_start(self): + self.task_pool._closed = True + mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() + try: + with self.assertRaises(AssertionError): + self.task_pool._check_start(awaitable=None, function=None) + with self.assertRaises(AssertionError): + self.task_pool._check_start(awaitable=mock_coroutine, function=mock_coroutine_function) + with self.assertRaises(exceptions.NotCoroutine): + self.task_pool._check_start(awaitable=mock_coroutine_function, function=None) + with self.assertRaises(exceptions.NotCoroutine): + self.task_pool._check_start(awaitable=None, function=mock_coroutine) + with self.assertRaises(exceptions.PoolIsClosed): + self.task_pool._check_start(awaitable=mock_coroutine, function=None) + self.task_pool._closed = False + self.task_pool._locked = True + with self.assertRaises(exceptions.PoolIsLocked): + self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) + self.assertIsNone(self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=True)) + finally: + await mock_coroutine + def test__task_name(self): i = 123 self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) @@ -171,12 +200,12 @@ class BaseTaskPoolTestCase(CommonTestCase): @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() - self.task_pool._num_cancelled = cancelled = 3 - self.task_pool._running[task_id] = mock_task + self.task_pool._num_cancellations = cancelled = 3 + self.task_pool._tasks_running[task_id] = mock_task self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback)) - self.assertNotIn(task_id, self.task_pool._running) - self.assertEqual(mock_task, self.task_pool._cancelled[task_id]) - self.assertEqual(cancelled + 1, self.task_pool._num_cancelled) + self.assertNotIn(task_id, self.task_pool._tasks_running) + self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id]) + self.assertEqual(cancelled + 1, self.task_pool._num_cancellations) mock__task_name.assert_called_with(task_id) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) @@ -184,15 +213,13 @@ class BaseTaskPoolTestCase(CommonTestCase): @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() - self.task_pool._num_ended = ended = 3 self.task_pool._enough_room._value = room = 123 # End running task: - self.task_pool._running[task_id] = mock_task + self.task_pool._tasks_running[task_id] = mock_task self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) - self.assertNotIn(task_id, self.task_pool._running) - self.assertEqual(mock_task, self.task_pool._ended[task_id]) - self.assertEqual(ended + 1, self.task_pool._num_ended) + self.assertNotIn(task_id, self.task_pool._tasks_running) + self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id]) self.assertEqual(room + 1, self.task_pool._enough_room._value) mock__task_name.assert_called_with(task_id) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) @@ -200,11 +227,10 @@ class BaseTaskPoolTestCase(CommonTestCase): mock_execute_optional.reset_mock() # End cancelled task: - self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id) + self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_ended.pop(task_id) self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) - self.assertNotIn(task_id, self.task_pool._cancelled) - self.assertEqual(mock_task, self.task_pool._ended[task_id]) - self.assertEqual(ended + 2, self.task_pool._num_ended) + self.assertNotIn(task_id, self.task_pool._tasks_cancelled) + self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id]) self.assertEqual(room + 2, self.task_pool._enough_room._value) mock__task_name.assert_called_with(task_id) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) @@ -246,92 +272,52 @@ class BaseTaskPoolTestCase(CommonTestCase): @patch.object(pool, 'create_task') @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) - async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock, - mock_create_task: MagicMock): - def reset_mocks() -> None: - mock__task_name.reset_mock() - mock__task_wrapper.reset_mock() - mock_create_task.reset_mock() - + @patch.object(pool, 'TaskGroupRegister') + @patch.object(pool.BaseTaskPool, '_check_start') + async def test__start_task(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__task_name: MagicMock, + mock__task_wrapper: AsyncMock, mock_create_task: MagicMock): + mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock_create_task.return_value = mock_task = MagicMock() mock__task_wrapper.return_value = mock_wrapped = MagicMock() - mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock() - self.task_pool._counter = count = 123 + mock_coroutine, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock() + self.task_pool._num_started = count = 123 self.task_pool._enough_room._value = room = 123 - - def check_nothing_changed() -> None: - self.assertEqual(count, self.task_pool._counter) - self.assertNotIn(count, self.task_pool._running) - self.assertEqual(room, self.task_pool._enough_room._value) - mock__task_name.assert_not_called() - mock__task_wrapper.assert_not_called() - mock_create_task.assert_not_called() - reset_mocks() - - with self.assertRaises(exceptions.NotCoroutine): - await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) - check_nothing_changed() - - self.task_pool._locked = True - ignore_closed = False - mock_awaitable = mock_coroutine() - with self.assertRaises(exceptions.PoolIsLocked): - await self.task_pool._start_task(mock_awaitable, ignore_closed, - end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) - await mock_awaitable - check_nothing_changed() - - ignore_closed = True - mock_awaitable = mock_coroutine() - output = await self.task_pool._start_task(mock_awaitable, ignore_closed, + group_name, ignore_lock = 'testgroup', True + output = await self.task_pool._start_task(mock_coroutine, group_name=group_name, ignore_lock=ignore_lock, end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) - await mock_awaitable self.assertEqual(count, output) - self.assertEqual(count + 1, self.task_pool._counter) - self.assertEqual(mock_task, self.task_pool._running[count]) + mock__check_start.assert_called_once_with(awaitable=mock_coroutine, ignore_lock=ignore_lock) self.assertEqual(room - 1, self.task_pool._enough_room._value) + self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name]) + mock_reg_cls.assert_called_once_with() + mock_group_reg.__aenter__.assert_awaited_once_with() + mock_group_reg.add.assert_called_once_with(count) mock__task_name.assert_called_once_with(count) - mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) - mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) - reset_mocks() - self.task_pool._counter = count - self.task_pool._enough_room._value = room - del self.task_pool._running[count] - - mock_awaitable = mock_coroutine() - mock_create_task.side_effect = test_exception = TestException() - with self.assertRaises(TestException) as e: - await self.task_pool._start_task(mock_awaitable, ignore_closed, - end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) - self.assertEqual(test_exception, e) - await mock_awaitable - self.assertEqual(count + 1, self.task_pool._counter) - self.assertNotIn(count, self.task_pool._running) - self.assertEqual(room, self.task_pool._enough_room._value) - mock__task_name.assert_called_once_with(count) - mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) - mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) + mock__task_wrapper.assert_called_once_with(mock_coroutine, count, mock_end_cb, mock_cancel_cb) + mock_create_task.assert_called_once_with(coro=mock_wrapped, name=FOO) + self.assertEqual(mock_task, self.task_pool._tasks_running[count]) + mock_group_reg.__aexit__.assert_awaited_once() @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) def test__get_running_task(self, mock__task_name: MagicMock): task_id, mock_task = 555, MagicMock() - self.task_pool._running[task_id] = mock_task + self.task_pool._tasks_running[task_id] = mock_task output = self.task_pool._get_running_task(task_id) self.assertEqual(mock_task, output) - self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id) + self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_running.pop(task_id) with self.assertRaises(exceptions.AlreadyCancelled): self.task_pool._get_running_task(task_id) mock__task_name.assert_called_once_with(task_id) mock__task_name.reset_mock() - self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id) + self.task_pool._tasks_ended[task_id] = self.task_pool._tasks_cancelled.pop(task_id) with self.assertRaises(exceptions.TaskEnded): self.task_pool._get_running_task(task_id) mock__task_name.assert_called_once_with(task_id) mock__task_name.reset_mock() - del self.task_pool._ended[task_id] + del self.task_pool._tasks_ended[task_id] with self.assertRaises(exceptions.InvalidTaskID): self.task_pool._get_running_task(task_id) mock__task_name.assert_not_called() @@ -344,263 +330,416 @@ class BaseTaskPoolTestCase(CommonTestCase): mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) - def test_cancel_all(self): - mock_task1, mock_task2 = MagicMock(), MagicMock() - self.task_pool._running = {1: mock_task1, 2: mock_task2} - assert not self.task_pool._interrupt_flag.is_set() - self.assertIsNone(self.task_pool.cancel_all(FOO)) - self.assertTrue(self.task_pool._interrupt_flag.is_set()) - mock_task1.cancel.assert_called_once_with(msg=FOO) - mock_task2.cancel.assert_called_once_with(msg=FOO) + def test__cancel_and_remove_all_from_group(self): + task_id = 555 + mock_cancel = MagicMock() + self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel) + + class MockRegister(set, MagicMock): + pass + self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO)) + mock_cancel.assert_called_once_with(msg=FOO) + + @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') + async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): + mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock() + mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit) + self.task_pool._task_groups[FOO] = mock_group_reg + with self.assertRaises(exceptions.InvalidGroupName): + await self.task_pool.cancel_group(BAR) + mock__cancel_and_remove_all_from_group.assert_not_called() + mock_grp_aenter.assert_not_called() + mock_grp_aexit.assert_not_called() + self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR)) + mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR) + mock_grp_aenter.assert_awaited_once_with() + mock_grp_aexit.assert_awaited_once() + + @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') + async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): + mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock() + mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit) + self.task_pool._task_groups[BAR] = mock_group_reg + self.assertIsNone(await self.task_pool.cancel_all(FOO)) + mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO) + mock_grp_aenter.assert_awaited_once_with() + mock_grp_aexit.assert_awaited_once() async def test_flush(self): - test_exception = TestException() - mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) - self.task_pool._ended = {123: mock_ended_func()} - self.task_pool._cancelled = {456: mock_cancelled_func()} - self.task_pool._interrupt_flag.set() - output = await self.task_pool.flush(return_exceptions=True) - self.assertListEqual([FOO, test_exception], output) - self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) - self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) - self.assertFalse(self.task_pool._interrupt_flag.is_set()) + mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) + self.task_pool._tasks_ended = {123: mock_ended_func()} + self.task_pool._tasks_cancelled = {456: mock_cancelled_func()} + self.assertIsNone(await self.task_pool.flush(return_exceptions=True)) + mock_ended_func.assert_awaited_once_with() + mock_cancelled_func.assert_awaited_once_with() + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) - self.task_pool._ended = {123: mock_ended_func()} - self.task_pool._cancelled = {456: mock_cancelled_func()} - output = await self.task_pool.flush(return_exceptions=True) - self.assertListEqual([FOO, test_exception], output) - self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) - self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) + async def test_gather_and_close(self): + mock_before_gather, mock_running_func = AsyncMock(), AsyncMock() + mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) + self.task_pool._before_gathering = before_gather = [mock_before_gather()] + self.task_pool._tasks_ended = ended = {123: mock_ended_func()} + self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()} + self.task_pool._tasks_running = running = {789: mock_running_func()} - async def test_gather(self): - test_exception = TestException() - mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) - mock_running_func = AsyncMock(return_value=BAR) - mock_queue_join = AsyncMock() - self.task_pool._before_gathering = before_gather = [mock_queue_join()] - self.task_pool._ended = ended = {123: mock_ended_func()} - self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()} - self.task_pool._running = running = {789: mock_running_func()} - self.task_pool._interrupt_flag.set() - - assert not self.task_pool._locked with self.assertRaises(exceptions.PoolStillUnlocked): - await self.task_pool.gather() - self.assertDictEqual(self.task_pool._ended, ended) - self.assertDictEqual(self.task_pool._cancelled, cancelled) - self.assertDictEqual(self.task_pool._running, running) - self.assertListEqual(self.task_pool._before_gathering, before_gather) - self.assertTrue(self.task_pool._interrupt_flag.is_set()) + await self.task_pool.gather_and_close() + self.assertDictEqual(ended, self.task_pool._tasks_ended) + self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled) + self.assertDictEqual(running, self.task_pool._tasks_running) + self.assertListEqual(before_gather, self.task_pool._before_gathering) + self.assertFalse(self.task_pool._closed) self.task_pool._locked = True - - def check_assertions(output) -> None: - self.assertListEqual([FOO, test_exception, BAR], output) - self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) - self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) - self.assertDictEqual(self.task_pool._running, EMPTY_DICT) - self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) - self.assertFalse(self.task_pool._interrupt_flag.is_set()) - - check_assertions(await self.task_pool.gather(return_exceptions=True)) - - self.task_pool._before_gathering = [mock_queue_join()] - self.task_pool._ended = {123: mock_ended_func()} - self.task_pool._cancelled = {456: mock_cancelled_func()} - self.task_pool._running = {789: mock_running_func()} - check_assertions(await self.task_pool.gather(return_exceptions=True)) + self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) + mock_before_gather.assert_awaited_once_with() + mock_ended_func.assert_awaited_once_with() + mock_cancelled_func.assert_awaited_once_with() + mock_running_func.assert_awaited_once_with() + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) + self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) + self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering) + self.assertTrue(self.task_pool._closed) class TaskPoolTestCase(CommonTestCase): TEST_CLASS = pool.TaskPool task_pool: pool.TaskPool - @patch.object(pool.TaskPool, '_start_task') - async def test__apply_one(self, mock__start_task: AsyncMock): - mock__start_task.return_value = expected_output = 12345 - mock_awaitable = MagicMock() - mock_func = MagicMock(return_value=mock_awaitable) - args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} - end_cb, cancel_cb = MagicMock(), MagicMock() - output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb) + def setUp(self) -> None: + self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__') + self.base_class_init = self.base_class_init_patcher.start() + super().setUp() + + def tearDown(self) -> None: + self.base_class_init_patcher.stop() + super().tearDown() + + def test_init(self): + self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) + self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME) + + def test__cancel_group_meta_tasks(self): + mock_task1, mock_task2 = MagicMock(), MagicMock() + self.task_pool._group_meta_tasks_running[BAR] = {mock_task1, mock_task2} + self.assertIsNone(self.task_pool._cancel_group_meta_tasks(FOO)) + self.assertDictEqual({BAR: {mock_task1, mock_task2}}, self.task_pool._group_meta_tasks_running) + self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) + mock_task1.cancel.assert_not_called() + mock_task2.cancel.assert_not_called() + + self.assertIsNone(self.task_pool._cancel_group_meta_tasks(BAR)) + self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) + self.assertSetEqual({mock_task1, mock_task2}, self.task_pool._meta_tasks_cancelled) + mock_task1.cancel.assert_called_once_with() + mock_task2.cancel.assert_called_once_with() + + @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') + @patch.object(pool.TaskPool, '_cancel_group_meta_tasks') + def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock, + mock_base__cancel_and_remove_all_from_group: MagicMock): + group_name, group_reg, msg = 'xyz', MagicMock(), FOO + self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)) + mock__cancel_group_meta_tasks.assert_called_once_with(group_name) + mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg) + + @patch.object(pool.BaseTaskPool, 'cancel_group') + async def test_cancel_group(self, mock_base_cancel_group: AsyncMock): + group_name, msg = 'abc', 'xyz' + await self.task_pool.cancel_group(group_name, msg=msg) + mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg) + + @patch.object(pool.BaseTaskPool, 'cancel_all') + async def test_cancel_all(self, mock_base_cancel_all: AsyncMock): + msg = 'xyz' + await self.task_pool.cancel_all(msg=msg) + mock_base_cancel_all.assert_awaited_once_with(msg=msg) + + def test__pop_ended_meta_tasks(self): + mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True) + self.task_pool._group_meta_tasks_running[FOO] = {mock_task, mock_done_task1} + mock_done_task2, mock_done_task3 = MagicMock(done=lambda: True), MagicMock(done=lambda: True) + self.task_pool._group_meta_tasks_running[BAR] = {mock_done_task2, mock_done_task3} + expected_output = {mock_done_task1, mock_done_task2, mock_done_task3} + output = self.task_pool._pop_ended_meta_tasks() + self.assertSetEqual(expected_output, output) + self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running) + + @patch.object(pool.TaskPool, '_pop_ended_meta_tasks') + @patch.object(pool.BaseTaskPool, 'flush') + async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock): + mock_ended_meta_task = AsyncMock() + mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()} + mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) + self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} + self.assertIsNone(await self.task_pool.flush(return_exceptions=False)) + mock_base_flush.assert_awaited_once_with(return_exceptions=False) + mock__pop_ended_meta_tasks.assert_called_once_with() + mock_ended_meta_task.assert_awaited_once_with() + mock_cancelled_meta_task.assert_awaited_once_with() + self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) + + @patch.object(pool.BaseTaskPool, 'gather_and_close') + async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock): + mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock() + self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}} + mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) + self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} + self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) + mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True) + mock_meta_task1.assert_awaited_once_with() + mock_meta_task2.assert_awaited_once_with() + mock_cancelled_meta_task.assert_awaited_once_with() + self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) + self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) + + @patch.object(pool, 'datetime') + def test__generate_group_name(self, mock_datetime: MagicMock): + prefix, func = 'x y z', AsyncMock(__name__=BAR) + dt = datetime(1776, 7, 4, 0, 0, 1) + mock_datetime.now = MagicMock(return_value=dt) + expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}' + output = pool.TaskPool._generate_group_name(prefix, func) self.assertEqual(expected_output, output) - mock_func.assert_called_once_with(*args, **kwargs) - mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) + + @patch.object(pool.TaskPool, '_start_task') + async def test__apply_num(self, mock__start_task: AsyncMock): + group_name = FOO + BAR + mock_awaitable = object() + mock_func = MagicMock(return_value=mock_awaitable) + args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)) + mock_func.assert_has_calls(3 * [call(*args, **kwargs)]) + mock__start_task.assert_has_awaits(3 * [ + call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) + ]) mock_func.reset_mock() mock__start_task.reset_mock() - output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb) - self.assertEqual(expected_output, output) - mock_func.assert_called_once_with(*args) - mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) + self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb)) + mock_func.assert_has_calls(num * [call(*args)]) + mock__start_task.assert_has_awaits(num * [ + call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) + ]) - @patch.object(pool.TaskPool, '_apply_one') - async def test_apply(self, mock__apply_one: AsyncMock): - mock__apply_one.return_value = mock_id = 67890 - mock_func, num = MagicMock(), 3 + @patch.object(pool, 'create_task') + @patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock()) + @patch.object(pool, 'TaskGroupRegister') + @patch.object(pool.TaskPool, '_generate_group_name') + @patch.object(pool.BaseTaskPool, '_check_start') + async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, + mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): + mock__generate_group_name.return_value = generated_name = 'name 123' + mock_group_reg = set_up_mock_group_register(mock_reg_cls) + mock__apply_num.return_value = mock_apply_coroutine = object() + mock_task_future = AsyncMock() + mock_create_task.return_value = mock_task_future() + mock_func, num, group_name = MagicMock(), 3, FOO + BAR args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} end_cb, cancel_cb = MagicMock(), MagicMock() - expected_output = num * [mock_id] - output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb) - self.assertEqual(expected_output, output) - mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)]) + self.task_pool._task_groups = {} - async def test__queue_producer(self): + def check_assertions(_group_name, _output): + self.assertEqual(_group_name, _output) + mock__check_start.assert_called_once_with(function=mock_func) + self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name]) + mock_group_reg.__aenter__.assert_awaited_once_with() + mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb) + mock_create_task.assert_called_once_with(mock_apply_coroutine) + mock_group_reg.__aexit__.assert_awaited_once() + mock_task_future.assert_awaited_once_with() + + output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb) + check_assertions(group_name, output) + mock__generate_group_name.assert_not_called() + + mock__check_start.reset_mock() + self.task_pool._task_groups.clear() + mock_group_reg.__aenter__.reset_mock() + mock__apply_num.reset_mock() + mock_create_task.reset_mock() + mock_group_reg.__aexit__.reset_mock() + mock_task_future = AsyncMock() + mock_create_task.return_value = mock_task_future() + + output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) + check_assertions(generated_name, output) + mock__generate_group_name.assert_called_once_with('apply', mock_func) + + @patch.object(pool, 'Queue') + async def test__queue_producer(self, mock_queue_cls: MagicMock): mock_put = AsyncMock() - mock_q = MagicMock(put=mock_put) - args = (FOO, BAR, 123) - assert not self.task_pool._interrupt_flag.is_set() - self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) - mock_put.assert_has_awaits([call(arg) for arg in args]) + mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put) + item1, item2, item3 = FOO, 420, 69 + arg_iter = iter([item1, item2, item3]) + self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR)) + mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)]) + with self.assertRaises(StopIteration): + next(arg_iter) + mock_put.reset_mock() - self.task_pool._interrupt_flag.set() - self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) - mock_put.assert_not_awaited() - @patch.object(pool, 'partial') - @patch.object(pool, 'star_function') - @patch.object(pool.TaskPool, '_start_task') - async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock, - mock_partial: MagicMock): - mock_partial.return_value = queue_callback = 'not really' - mock_star_function.return_value = awaitable = 'totally an awaitable' - q, arg = Queue(), 420.69 - q.put_nowait(arg) - mock_func, stars = MagicMock(), 3 - mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb)) - self.assertTrue(q.empty()) - mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True, - end_callback=queue_callback, cancel_callback=cancel_cb) - mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars) - mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool, - q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars, - end_callback=end_cb, cancel_callback=cancel_cb) - mock__start_task.reset_mock() - mock_star_function.reset_mock() - mock_partial.reset_mock() - - self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb)) - self.assertTrue(q.empty()) - mock__start_task.assert_not_awaited() - mock_star_function.assert_not_called() - mock_partial.assert_not_called() + mock_put.side_effect = [CancelledError, None] + arg_iter = iter([item1, item2, item3]) + mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty] + self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR)) + mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)]) + mock_queue.get_nowait.assert_has_calls([call(), call(), call()]) + mock_queue.item_processed.assert_has_calls([call(), call()]) + self.assertListEqual([item2, item3], list(arg_iter)) @patch.object(pool, 'execute_optional') - @patch.object(pool.TaskPool, '_queue_consumer') - async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock): - task_id, mock_q = 420, MagicMock() - mock_func, stars = MagicMock(), 3 - mock_wait = AsyncMock() - mock_flag = MagicMock(wait=mock_wait) - end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars, - end_callback=end_cb, cancel_callback=cancel_cb)) - mock_wait.assert_awaited_once_with() - mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars, - end_callback=end_cb, cancel_callback=cancel_cb) - mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,)) + async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock): + semaphore, mock_end_cb = Semaphore(1), MagicMock() + wrapped = pool.TaskPool._get_map_end_callback(semaphore, mock_end_cb) + task_id = 1234 + await wrapped(task_id) + self.assertEqual(2, semaphore._value) + mock_execute_optional.assert_awaited_once_with(mock_end_cb, args=(task_id,)) + + @patch.object(pool, 'star_function') + @patch.object(pool.TaskPool, '_start_task') + @patch.object(pool, 'Semaphore') + @patch.object(pool.TaskPool, '_get_map_end_callback') + async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock, + mock__start_task: AsyncMock, mock_star_function: MagicMock): + mock__get_map_end_callback.return_value = map_cb = MagicMock() + mock_semaphore_cls.return_value = semaphore = Semaphore(3) + mock_star_function.return_value = awaitable = 'totally an awaitable' + arg1, arg2 = 123456789, 'function argument' + mock_q_maxsize = 3 + mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]), + __aexit__=AsyncMock(), maxsize=mock_q_maxsize) + group_name, mock_func, stars = 'whatever', MagicMock(), 3 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) + self.assertTrue(semaphore.locked()) + mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) + mock__start_task.assert_has_awaits(2 * [ + call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) + ]) + mock_star_function.assert_has_calls([ + call(mock_func, arg1, arg_stars=stars), + call(mock_func, arg2, arg_stars=stars) + ]) - @patch.object(pool, 'iter') @patch.object(pool, 'create_task') - @patch.object(pool, 'join_queue', new_callable=MagicMock) + @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) - async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock, - mock_create_task: MagicMock, mock_iter: MagicMock): - args, num_tasks = (FOO, BAR, 1, 2, 3), 2 - mock_join_queue.return_value = mock_join = 'awaitable' - mock_iter.return_value = args_iter = iter(args) - mock__queue_producer.return_value = mock_producer_coro = 'very awaitable' - output_q = self.task_pool._set_up_args_queue(args, num_tasks) - self.assertIsInstance(output_q, Queue) - self.assertEqual(num_tasks, output_q.qsize()) - for arg in args[:num_tasks]: - self.assertEqual(arg, output_q.get_nowait()) - self.assertTrue(output_q.empty()) - for arg in args[num_tasks:]: - self.assertEqual(arg, next(args_iter)) - with self.assertRaises(StopIteration): - next(args_iter) - self.assertListEqual([mock_join], self.task_pool._before_gathering) - mock_join_queue.assert_called_once_with(output_q) - mock__queue_producer.assert_called_once_with(output_q, args_iter) - mock_create_task.assert_called_once_with(mock_producer_coro) + @patch.object(pool, 'join_queue', new_callable=MagicMock) + @patch.object(pool, 'Queue') + @patch.object(pool, 'TaskGroupRegister') + @patch.object(pool.BaseTaskPool, '_check_start') + async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock, + mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, + mock_create_task: MagicMock): + mock_group_reg = set_up_mock_group_register(mock_reg_cls) + mock_queue_cls.return_value = mock_q = MagicMock() + mock_join_queue.return_value = fake_join = object() + mock__queue_producer.return_value = fake_producer = object() + mock__queue_consumer.return_value = fake_consumer = object() + fake_task1, fake_task2 = object(), object() + mock_create_task.side_effect = [fake_task1, fake_task2] - self.task_pool._before_gathering.clear() - mock_join_queue.reset_mock() - mock__queue_producer.reset_mock() - mock_create_task.reset_mock() - - num_tasks = 6 - mock_iter.return_value = args_iter = iter(args) - output_q = self.task_pool._set_up_args_queue(args, num_tasks) - self.assertIsInstance(output_q, Queue) - self.assertEqual(len(args), output_q.qsize()) - for arg in args: - self.assertEqual(arg, output_q.get_nowait()) - self.assertTrue(output_q.empty()) - with self.assertRaises(StopIteration): - next(args_iter) - self.assertListEqual([mock_join], self.task_pool._before_gathering) - mock_join_queue.assert_called_once_with(output_q) - mock__queue_producer.assert_not_called() - mock_create_task.assert_not_called() - - @patch.object(pool, 'Event') - @patch.object(pool.TaskPool, '_queue_consumer') - @patch.object(pool.TaskPool, '_set_up_args_queue') - async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock, - mock_event_cls: MagicMock): - qsize = 4 - mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize)) - mock_flag_set = MagicMock() - mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set) - - mock_func, stars = MagicMock(), 3 - args_iter, group_size = (FOO, BAR, 1, 2, 3), 2 + group_name, group_size = 'onetwothree', 0 + func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3 end_cb, cancel_cb = MagicMock(), MagicMock() - self.task_pool._locked = False - with self.assertRaises(exceptions.PoolIsLocked): - await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb) - mock__set_up_args_queue.assert_not_called() - mock__queue_consumer.assert_not_awaited() - mock_flag_set.assert_not_called() + with self.assertRaises(ValueError): + await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) + mock__check_start.assert_called_once_with(function=func) - self.task_pool._locked = True - self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb)) - mock__set_up_args_queue.assert_called_once_with(args_iter, group_size) - mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars, - end_callback=end_cb, cancel_callback=cancel_cb)]) - mock_flag_set.assert_called_once_with() + mock__check_start.reset_mock() + + group_size = 1234 + self.task_pool._task_groups = {group_name: MagicMock()} + + with self.assertRaises(exceptions.InvalidGroupName): + await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) + mock__check_start.assert_called_once_with(function=func) + + mock__check_start.reset_mock() + + self.task_pool._task_groups.clear() + self.task_pool._before_gathering = [] + + self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)) + mock__check_start.assert_called_once_with(function=func) + mock_reg_cls.assert_called_once_with() + self.task_pool._task_groups[group_name] = mock_group_reg + mock_group_reg.__aenter__.assert_awaited_once_with() + mock_queue_cls.assert_called_once_with(maxsize=group_size) + mock_join_queue.assert_called_once_with(mock_q) + self.assertListEqual([fake_join], self.task_pool._before_gathering) + mock__queue_producer.assert_called_once() + mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb) + mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)]) + self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name]) + mock_group_reg.__aexit__.assert_awaited_once() @patch.object(pool.TaskPool, '_map') - async def test_map(self, mock__map: AsyncMock): + @patch.object(pool.TaskPool, '_generate_group_name') + async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): + mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock_func = MagicMock() - arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2 + arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb)) - mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size, + output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb) + self.assertEqual(group_name, output) + mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0, end_callback=end_cb, cancel_callback=cancel_cb) + mock__generate_group_name.assert_not_called() + + mock__map.reset_mock() + output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb) + self.assertEqual(generated_name, output) + mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0, + end_callback=end_cb, cancel_callback=cancel_cb) + mock__generate_group_name.assert_called_once_with('map', mock_func) @patch.object(pool.TaskPool, '_map') - async def test_starmap(self, mock__map: AsyncMock): + @patch.object(pool.TaskPool, '_generate_group_name') + async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): + mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock_func = MagicMock() - args_iter, group_size = ([FOO], [BAR]), 2 + args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb)) - mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size, + output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb) + self.assertEqual(group_name, output) + mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1, end_callback=end_cb, cancel_callback=cancel_cb) + mock__generate_group_name.assert_not_called() + + mock__map.reset_mock() + output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb) + self.assertEqual(generated_name, output) + mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1, + end_callback=end_cb, cancel_callback=cancel_cb) + mock__generate_group_name.assert_called_once_with('starmap', mock_func) @patch.object(pool.TaskPool, '_map') - async def test_doublestarmap(self, mock__map: AsyncMock): + @patch.object(pool.TaskPool, '_generate_group_name') + async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): + mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock_func = MagicMock() - kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2 + kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb)) - mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size, + output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb) + self.assertEqual(group_name, output) + mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2, end_callback=end_cb, cancel_callback=cancel_cb) + mock__generate_group_name.assert_not_called() + + mock__map.reset_mock() + output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb) + self.assertEqual(generated_name, output) + mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2, + end_callback=end_cb, cancel_callback=cancel_cb) + mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func) class SimpleTaskPoolTestCase(CommonTestCase): @@ -667,7 +806,7 @@ class SimpleTaskPoolTestCase(CommonTestCase): def test_stop(self, mock_cancel: MagicMock): num = 2 id1, id2, id3 = 5, 6, 7 - self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR} + self.task_pool._tasks_running = {id1: FOO, id2: BAR, id3: FOO + BAR} output = self.task_pool.stop(num) expected_output = [id3, id2] self.assertEqual(expected_output, output) @@ -689,3 +828,10 @@ class SimpleTaskPoolTestCase(CommonTestCase): self.assertEqual(expected_output, output) mock_num_running.assert_called_once_with() mock_stop.assert_called_once_with(num) + + +def set_up_mock_group_register(mock_reg_cls: MagicMock) -> MagicMock: + mock_grp_aenter, mock_grp_aexit, mock_grp_add = AsyncMock(), AsyncMock(), MagicMock() + mock_reg_cls.return_value = mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit, + add=mock_grp_add) + return mock_group_reg diff --git a/usage/USAGE.md b/usage/USAGE.md index 33272b4..741bbb1 100644 --- a/usage/USAGE.md +++ b/usage/USAGE.md @@ -28,7 +28,7 @@ async def work(n: int) -> None: """ for i in range(n): await asyncio.sleep(1) - print("did", i) + print("> did", i) async def main() -> None: @@ -39,7 +39,7 @@ async def main() -> None: await asyncio.sleep(1.5) # lets the tasks work for a bit pool.stop(2) # cancels tasks 3 and 2 pool.lock() # required for the last line - await pool.gather() # awaits all tasks, then flushes the pool + await pool.gather_and_close() # awaits all tasks, then flushes the pool if __name__ == '__main__': @@ -52,29 +52,29 @@ SimpleTaskPool-0 initialized Started SimpleTaskPool-0_Task-0 Started SimpleTaskPool-0_Task-1 Started SimpleTaskPool-0_Task-2 -did 0 -did 0 -did 0 +> did 0 +> did 0 +> did 0 Started SimpleTaskPool-0_Task-3 -did 1 -did 1 -did 1 -did 0 +> did 1 +> did 1 +> did 1 +> did 0 +> did 2 +> did 2 SimpleTaskPool-0 is locked! -Cancelling SimpleTaskPool-0_Task-3 ... -Cancelled SimpleTaskPool-0_Task-3 -Ended SimpleTaskPool-0_Task-3 Cancelling SimpleTaskPool-0_Task-2 ... Cancelled SimpleTaskPool-0_Task-2 Ended SimpleTaskPool-0_Task-2 -did 2 -did 2 -did 3 -did 3 +Cancelling SimpleTaskPool-0_Task-3 ... +Cancelled SimpleTaskPool-0_Task-3 +Ended SimpleTaskPool-0_Task-3 +> did 3 +> did 3 Ended SimpleTaskPool-0_Task-0 Ended SimpleTaskPool-0_Task-1 -did 4 -did 4 +> did 4 +> did 4 ``` ## Advanced example for `TaskPool` @@ -101,21 +101,21 @@ async def work(start: int, stop: int, step: int = 1) -> None: """Pseudo-worker function counting through a range with a second of sleep in between each iteration.""" for i in range(start, stop, step): await asyncio.sleep(1) - print("work with", i) + print("> work with", i) async def other_work(a: int, b: int) -> None: """Different pseudo-worker counting through a range with half a second of sleep in between each iteration.""" for i in range(a, b): await asyncio.sleep(0.5) - print("other_work with", i) + print("> other_work with", i) async def main() -> None: # Initialize a new task pool instance and limit its size to 3 tasks. pool = TaskPool(3) # Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments). - print("Called `apply`") + print("> Called `apply`") await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) # Let the tasks work for a bit. await asyncio.sleep(1.5) @@ -124,20 +124,18 @@ async def main() -> None: # Since we set our pool size to 3, and already have two tasks working within the pool, # only the first one of these will start immediately (and receive ID 2). # The second one will start (with ID 3), only once there is room in the pool, - # which -- in this example -- will be the case after ID 2 ends; - # until then the `starmap` method call **will block**! + # which -- in this example -- will be the case after ID 2 ends. # Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5) - # **only** once there is room in the pool **and** no more than one of these last four tasks is running. + # **only** once there is room in the pool **and** no more than one other task of these new ones is running. args_list = [(0, 10), (10, 20), (20, 30), (30, 40)] - print("Calling `starmap`...") await pool.starmap(other_work, args_list, group_size=2) - print("`starmap` returned") + print("> Called `starmap`") # Now we lock the pool, so that we can safely await all our tasks. pool.lock() # Finally, we block, until all tasks have ended. - print("Called `gather`") - await pool.gather() - print("Done.") + print("> Calling `gather_and_close`...") + await pool.gather_and_close() + print("> Done.") if __name__ == '__main__': @@ -152,82 +150,81 @@ Additional comments for the output are provided with `<---` next to the output l TaskPool-0 initialized Started TaskPool-0_Task-0 Started TaskPool-0_Task-1 -Called `apply` -work with 100 -work with 100 -Calling `starmap`... <--- notice that this blocks as expected -Started TaskPool-0_Task-2 -work with 110 -work with 110 -other_work with 0 -other_work with 1 -work with 120 -work with 120 -other_work with 2 -other_work with 3 -work with 130 -work with 130 -other_work with 4 -other_work with 5 -work with 140 -work with 140 -other_work with 6 -other_work with 7 -work with 150 -work with 150 -other_work with 8 -Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()` +> Called `apply` +> work with 100 +> work with 100 +> Called `starmap` <--- notice that this immediately returns, even before Task-2 is started +> Calling `gather_and_close`... <--- this blocks `main()` until all tasks have ended TaskPool-0 is locked! +Started TaskPool-0_Task-2 <--- at this point the pool is full +> work with 110 +> work with 110 +> other_work with 0 +> other_work with 1 +> work with 120 +> work with 120 +> other_work with 2 +> other_work with 3 +> work with 130 +> work with 130 +> other_work with 4 +> other_work with 5 +> work with 140 +> work with 140 +> other_work with 6 +> other_work with 7 +> work with 150 +> work with 150 +> other_work with 8 +Ended TaskPool-0_Task-2 <--- this frees up room for one more task from `starmap` Started TaskPool-0_Task-3 -other_work with 9 -`starmap` returned -Called `gather` <--- now this will block `main()` until all tasks have ended -work with 160 -work with 160 -other_work with 10 -other_work with 11 -work with 170 -work with 170 -other_work with 12 -other_work with 13 -work with 180 -work with 180 -other_work with 14 -other_work with 15 +> other_work with 9 +> work with 160 +> work with 160 +> other_work with 10 +> other_work with 11 +> work with 170 +> work with 170 +> other_work with 12 +> other_work with 13 +> work with 180 +> work with 180 +> other_work with 14 +> other_work with 15 Ended TaskPool-0_Task-0 -Ended TaskPool-0_Task-1 <--- even though there is room in the pool now, Task-5 will not start -Started TaskPool-0_Task-4 -work with 190 -work with 190 -other_work with 16 -other_work with 20 -other_work with 17 -other_work with 21 -other_work with 18 -other_work with 22 -other_work with 19 -Ended TaskPool-0_Task-3 <--- now that only Task-4 is left, Task-5 will start +Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool +Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start +> work with 190 +> work with 190 +> other_work with 16 +> other_work with 17 +> other_work with 20 +> other_work with 18 +> other_work with 21 +Ended TaskPool-0_Task-3 <--- now that only Task-4 of the group remains, Task-5 starts Started TaskPool-0_Task-5 -other_work with 23 -other_work with 30 -other_work with 24 -other_work with 31 -other_work with 25 -other_work with 32 -other_work with 26 -other_work with 33 -other_work with 27 -other_work with 34 -other_work with 28 -other_work with 35 +> other_work with 19 +> other_work with 22 +> other_work with 23 +> other_work with 30 +> other_work with 24 +> other_work with 31 +> other_work with 25 +> other_work with 32 +> other_work with 26 +> other_work with 33 +> other_work with 27 +> other_work with 34 +> other_work with 28 +> other_work with 35 +> other_work with 29 +> other_work with 36 Ended TaskPool-0_Task-4 -other_work with 29 -other_work with 36 -other_work with 37 -other_work with 38 -other_work with 39 -Done. +> other_work with 37 +> other_work with 38 +> other_work with 39 Ended TaskPool-0_Task-5 +> Done. ``` © 2022 Daniil Fajnberg diff --git a/usage/example_server.py b/usage/example_server.py index e5649ce..f4a30e0 100644 --- a/usage/example_server.py +++ b/usage/example_server.py @@ -74,12 +74,12 @@ async def main() -> None: control_server_task.cancel() # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, # we can now safely cancel their tasks. - pool.stop_all() pool.lock() + pool.stop_all() # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. # We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # we just silently collect their exceptions along with their return values. - await pool.gather(return_exceptions=True) + await pool.gather_and_close(return_exceptions=True) await control_server_task