generated from daniil-berg/boilerplate-py
Compare commits
9 Commits
3a27040a54
...
v0.0.2-lw
Author | SHA1 | Date | |
---|---|---|---|
ac903d9be7 | |||
e8e13406ea | |||
2d40f5707b | |||
c0c9246b87 | |||
ba0d5fca85 | |||
b5eed608b5 | |||
2f0b08edf0 | |||
a68e61dfa7 | |||
9ec5359fd6 |
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,3 +8,5 @@
|
||||
/dist/
|
||||
# Python cache:
|
||||
__pycache__/
|
||||
# Testing:
|
||||
.coverage
|
||||
|
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = asyncio-taskpool
|
||||
version = 0.0.1
|
||||
version = 0.0.2
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Dynamically manage pools of asyncio tasks
|
||||
|
@ -14,7 +14,7 @@ class AlreadyCancelled(TaskEnded):
|
||||
pass
|
||||
|
||||
|
||||
class AlreadyFinished(TaskEnded):
|
||||
class AlreadyEnded(TaskEnded):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -2,164 +2,348 @@ import logging
|
||||
from asyncio import gather
|
||||
from asyncio.coroutines import iscoroutinefunction
|
||||
from asyncio.exceptions import CancelledError
|
||||
from asyncio.locks import Event
|
||||
from asyncio.locks import Event, Semaphore
|
||||
from asyncio.tasks import Task, create_task
|
||||
from math import inf
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
from . import exceptions
|
||||
from .types import ArgsT, KwArgsT, CoroutineFunc, FinalCallbackT, CancelCallbackT
|
||||
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTaskPool:
|
||||
"""The base class for task pools. Not intended to be used directly."""
|
||||
_pools: List['BaseTaskPool'] = []
|
||||
|
||||
@classmethod
|
||||
def _add_pool(cls, pool: 'BaseTaskPool') -> int:
|
||||
"""Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list."""
|
||||
cls._pools.append(pool)
|
||||
return len(cls._pools) - 1
|
||||
|
||||
# TODO: Make use of `max_running`
|
||||
def __init__(self, max_running: int = inf, name: str = None) -> None:
|
||||
self._max_running: int = max_running
|
||||
def __init__(self, pool_size: int = inf, name: str = None) -> None:
|
||||
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
|
||||
self._enough_room: Semaphore = Semaphore()
|
||||
self.pool_size = pool_size
|
||||
self._open: bool = True
|
||||
self._counter: int = 0
|
||||
self._running: Dict[int, Task] = {}
|
||||
self._cancelled: Dict[int, Task] = {}
|
||||
self._ended: Dict[int, Task] = {}
|
||||
self._num_cancelled: int = 0
|
||||
self._num_ended: int = 0
|
||||
self._idx: int = self._add_pool(self)
|
||||
self._name: str = name
|
||||
self._all_tasks_known_flag: Event = Event()
|
||||
self._all_tasks_known_flag.set()
|
||||
self._more_allowed_flag: Event = Event()
|
||||
self._check_more_allowed()
|
||||
self._interrupt_flag: Event = Event()
|
||||
log.debug("%s initialized", str(self))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.__class__.__name__}-{self._name or self._idx}'
|
||||
|
||||
@property
|
||||
def max_running(self) -> int:
|
||||
return self._max_running
|
||||
def pool_size(self) -> int:
|
||||
"""Returns the maximum number of concurrently running tasks currently set in the pool."""
|
||||
return self._pool_size
|
||||
|
||||
@pool_size.setter
|
||||
def pool_size(self, value: int) -> None:
|
||||
"""
|
||||
Sets the maximum number of concurrently running tasks in the pool.
|
||||
|
||||
Args:
|
||||
value:
|
||||
A non-negative integer.
|
||||
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
|
||||
|
||||
Raises:
|
||||
`ValueError` if `value` is less than 0.
|
||||
"""
|
||||
if value < 0:
|
||||
raise ValueError("Pool size can not be less than 0")
|
||||
self._enough_room._value = value
|
||||
self._pool_size = value
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Returns `True` if more the pool has not been closed yet."""
|
||||
return self._open
|
||||
|
||||
@property
|
||||
def num_running(self) -> int:
|
||||
"""
|
||||
Returns the number of tasks in the pool that are (at that moment) still running.
|
||||
At the moment a task's `end_callback` is fired, it is no longer considered to be running.
|
||||
"""
|
||||
return len(self._running)
|
||||
|
||||
@property
|
||||
def num_cancelled(self) -> int:
|
||||
return len(self._cancelled)
|
||||
"""
|
||||
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
|
||||
At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
|
||||
"""
|
||||
return self._num_cancelled
|
||||
|
||||
@property
|
||||
def num_ended(self) -> int:
|
||||
return len(self._ended)
|
||||
"""
|
||||
Returns the number of tasks started through the pool that have stopped running (up until that moment).
|
||||
At the moment a task's `end_callback` is fired, it is considered ended.
|
||||
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
|
||||
does it then actually end.
|
||||
"""
|
||||
return self._num_ended
|
||||
|
||||
@property
|
||||
def num_finished(self) -> int:
|
||||
return self.num_ended - self.num_cancelled
|
||||
"""
|
||||
Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
|
||||
"""
|
||||
return self._num_ended - self._num_cancelled + len(self._cancelled)
|
||||
|
||||
@property
|
||||
def is_full(self) -> bool:
|
||||
return not self._more_allowed_flag.is_set()
|
||||
|
||||
def _check_more_allowed(self) -> None:
|
||||
if self.is_full and self.num_running < self.max_running:
|
||||
self._more_allowed_flag.set()
|
||||
elif not self.is_full and self.num_running >= self.max_running:
|
||||
self._more_allowed_flag.clear()
|
||||
"""
|
||||
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
|
||||
When the pool is full, any call to start a new task within it will block.
|
||||
"""
|
||||
return self._enough_room.locked()
|
||||
|
||||
# TODO: Consider adding task group names
|
||||
def _task_name(self, task_id: int) -> str:
|
||||
"""Returns a standardized name for a task with a specific `task_id`."""
|
||||
return f'{self}_Task-{task_id}'
|
||||
|
||||
async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
|
||||
async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
|
||||
"""
|
||||
Universal callback to be run upon any task in the pool being cancelled.
|
||||
Required for keeping track of running/cancelled tasks and proper logging.
|
||||
|
||||
Args:
|
||||
task_id:
|
||||
The ID of the task that has been cancelled.
|
||||
custom_callback (optional):
|
||||
A callback to execute after cancellation of the task.
|
||||
It is run at the end of this function with the `task_id` as its only positional argument.
|
||||
"""
|
||||
log.debug("Cancelling %s ...", self._task_name(task_id))
|
||||
task = self._running.pop(task_id)
|
||||
assert task is not None
|
||||
self._cancelled[task_id] = task
|
||||
await _execute_function(custom_callback, args=(task_id, ))
|
||||
self._cancelled[task_id] = self._running.pop(task_id)
|
||||
self._num_cancelled += 1
|
||||
log.debug("Cancelled %s", self._task_name(task_id))
|
||||
|
||||
async def _end_task(self, task_id: int, custom_callback: FinalCallbackT = None) -> None:
|
||||
task = self._running.pop(task_id, None)
|
||||
if task is None:
|
||||
task = self._cancelled[task_id]
|
||||
self._ended[task_id] = task
|
||||
await _execute_function(custom_callback, args=(task_id, ))
|
||||
self._check_more_allowed()
|
||||
log.info("Ended %s", self._task_name(task_id))
|
||||
|
||||
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, final_callback: FinalCallbackT = None,
|
||||
async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
|
||||
"""
|
||||
Universal callback to be run upon any task in the pool ending its work.
|
||||
Required for keeping track of running/cancelled/ended tasks and proper logging.
|
||||
Also releases room in the task pool for potentially waiting tasks.
|
||||
|
||||
Args:
|
||||
task_id:
|
||||
The ID of the task that has reached its end.
|
||||
custom_callback (optional):
|
||||
A callback to execute after the task has ended.
|
||||
It is run at the end of this function with the `task_id` as its only positional argument.
|
||||
"""
|
||||
try:
|
||||
self._ended[task_id] = self._running.pop(task_id)
|
||||
except KeyError:
|
||||
self._ended[task_id] = self._cancelled.pop(task_id)
|
||||
self._num_ended += 1
|
||||
self._enough_room.release()
|
||||
log.info("Ended %s", self._task_name(task_id))
|
||||
await _execute_function(custom_callback, args=(task_id, ))
|
||||
|
||||
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
|
||||
cancel_callback: CancelCallbackT = None) -> Any:
|
||||
"""
|
||||
Universal wrapper around every task to be run in the pool.
|
||||
Returns/raises whatever the wrapped coroutine does.
|
||||
|
||||
Args:
|
||||
awaitable:
|
||||
The actual coroutine to be run within the task pool.
|
||||
task_id:
|
||||
The ID of the newly created task.
|
||||
end_callback (optional):
|
||||
A callback to execute after the task has ended.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
cancel_callback (optional):
|
||||
A callback to execute after cancellation of the task.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
"""
|
||||
log.info("Started %s", self._task_name(task_id))
|
||||
try:
|
||||
return await awaitable
|
||||
except CancelledError:
|
||||
await self._cancel_task(task_id, custom_callback=cancel_callback)
|
||||
await self._task_cancellation(task_id, custom_callback=cancel_callback)
|
||||
finally:
|
||||
await self._end_task(task_id, custom_callback=final_callback)
|
||||
await self._task_ending(task_id, custom_callback=end_callback)
|
||||
|
||||
def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None,
|
||||
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
|
||||
cancel_callback: CancelCallbackT = None) -> int:
|
||||
if not (self._open or ignore_closed):
|
||||
"""
|
||||
Starts a coroutine as a new task in the pool.
|
||||
This method blocks, **only if** the pool is full.
|
||||
Returns/raises whatever the wrapped coroutine does.
|
||||
|
||||
Args:
|
||||
awaitable:
|
||||
The actual coroutine to be run within the task pool.
|
||||
ignore_closed (optional):
|
||||
If `True`, even if the pool is closed, the task will still be started.
|
||||
end_callback (optional):
|
||||
A callback to execute after the task has ended.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
cancel_callback (optional):
|
||||
A callback to execute after cancellation of the task.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
|
||||
Raises:
|
||||
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
|
||||
"""
|
||||
if not (self.is_open or ignore_closed):
|
||||
raise exceptions.PoolIsClosed("Cannot start new tasks")
|
||||
await self._enough_room.acquire()
|
||||
task_id = self._counter
|
||||
self._counter += 1
|
||||
try:
|
||||
self._running[task_id] = create_task(
|
||||
self._task_wrapper(awaitable, task_id, final_callback, cancel_callback),
|
||||
self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
|
||||
name=self._task_name(task_id)
|
||||
)
|
||||
self._check_more_allowed()
|
||||
except Exception as e:
|
||||
self._enough_room.release()
|
||||
raise e
|
||||
return task_id
|
||||
|
||||
def _cancel_one(self, task_id: int, msg: str = None) -> None:
|
||||
def _get_running_task(self, task_id: int) -> Task:
|
||||
"""
|
||||
Gets a running task by its task ID.
|
||||
|
||||
Args:
|
||||
task_id: The ID of a task still running within the pool.
|
||||
|
||||
Raises:
|
||||
`asyncio_taskpool.exceptions.AlreadyCancelled` if the task with `task_id` has been (recently) cancelled.
|
||||
`asyncio_taskpool.exceptions.AlreadyEnded` if the task with `task_id` has ended (recently).
|
||||
`asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool.
|
||||
"""
|
||||
try:
|
||||
task = self._running[task_id]
|
||||
return self._running[task_id]
|
||||
except KeyError:
|
||||
if self._cancelled.get(task_id):
|
||||
raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
|
||||
if self._ended.get(task_id):
|
||||
raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running")
|
||||
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
||||
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
||||
task.cancel(msg=msg)
|
||||
|
||||
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
||||
for task_id in task_ids:
|
||||
self._cancel_one(task_id, msg=msg)
|
||||
"""
|
||||
Cancels the tasks with the specified IDs.
|
||||
|
||||
Each task ID must belong to a task still running within the pool. Otherwise one of the following exceptions will
|
||||
be raised:
|
||||
- `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled.
|
||||
- `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently).
|
||||
- `InvalidTaskID` if any of the `task_ids` is not known to the pool.
|
||||
Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten.
|
||||
|
||||
Args:
|
||||
task_ids:
|
||||
Arbitrary number of integers. Each must be an ID of a task still running within the pool.
|
||||
msg (optional):
|
||||
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||
"""
|
||||
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
||||
for task in tasks:
|
||||
task.cancel(msg=msg)
|
||||
|
||||
def cancel_all(self, msg: str = None) -> None:
|
||||
"""
|
||||
Cancels all tasks still running within the pool.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, those already running will be cancelled, while the following will **never even start**.
|
||||
|
||||
Args:
|
||||
msg (optional):
|
||||
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||
"""
|
||||
log.warning("%s cancelling all tasks!", str(self))
|
||||
self._interrupt_flag.set()
|
||||
for task in self._running.values():
|
||||
task.cancel(msg=msg)
|
||||
|
||||
async def flush(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
|
||||
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||
callbacks registered for the tasks block.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
"""
|
||||
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
|
||||
self._ended = self._cancelled = {}
|
||||
if self._interrupt_flag.is_set():
|
||||
self._interrupt_flag.clear()
|
||||
return results
|
||||
|
||||
def close(self) -> None:
|
||||
"""Disallows any more tasks to be started in the pool."""
|
||||
self._open = False
|
||||
log.info("%s is closed!", str(self))
|
||||
|
||||
async def gather(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
|
||||
|
||||
The `close()` method must have been called prior to this.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
|
||||
blocking this method. Otherwise it will wait until they all have started.
|
||||
|
||||
This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
|
||||
callbacks registered for a task blocks.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
|
||||
Raises:
|
||||
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
|
||||
"""
|
||||
if self._open:
|
||||
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
|
||||
await self._all_tasks_known_flag.wait()
|
||||
results = await gather(*self._running.values(), *self._ended.values(), return_exceptions=return_exceptions)
|
||||
self._running = self._cancelled = self._ended = {}
|
||||
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._ended = self._cancelled = self._running = {}
|
||||
if self._interrupt_flag.is_set():
|
||||
self._interrupt_flag.clear()
|
||||
return results
|
||||
|
||||
|
||||
class TaskPool(BaseTaskPool):
|
||||
def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
|
||||
async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return self._start_task(func(*args, **kwargs), final_callback=final_callback, cancel_callback=cancel_callback)
|
||||
return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
|
||||
return tuple(self._apply_one(func, args, kwargs, final_callback, cancel_callback) for _ in range(num))
|
||||
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
|
||||
return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))
|
||||
|
||||
@staticmethod
|
||||
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
|
||||
@ -175,54 +359,54 @@ class TaskPool(BaseTaskPool):
|
||||
return func(**arg)
|
||||
raise ValueError
|
||||
|
||||
def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
|
||||
if self._all_tasks_known_flag.is_set():
|
||||
self._all_tasks_known_flag.clear()
|
||||
args_iter = iter(args_iter)
|
||||
|
||||
def _start_next_coroutine() -> bool:
|
||||
async def _start_next_coroutine() -> bool:
|
||||
cor = self._get_next_coroutine(func, args_iter, arg_stars)
|
||||
if cor is None:
|
||||
if cor is None or self._interrupt_flag.is_set():
|
||||
self._all_tasks_known_flag.set()
|
||||
return True
|
||||
self._start_task(cor, ignore_closed=True, final_callback=_start_next, cancel_callback=cancel_callback)
|
||||
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
|
||||
return False
|
||||
|
||||
async def _start_next(task_id: int) -> None:
|
||||
await _execute_function(final_callback, args=(task_id, ))
|
||||
_start_next_coroutine()
|
||||
await _start_next_coroutine()
|
||||
await _execute_function(end_callback, args=(task_id, ))
|
||||
|
||||
for _ in range(num_tasks):
|
||||
reached_end = _start_next_coroutine()
|
||||
reached_end = await _start_next_coroutine()
|
||||
if reached_end:
|
||||
break
|
||||
|
||||
def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
|
||||
final_callback=final_callback, cancel_callback=cancel_callback)
|
||||
async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
|
||||
final_callback=final_callback, cancel_callback=cancel_callback)
|
||||
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
|
||||
final_callback=final_callback, cancel_callback=cancel_callback)
|
||||
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
|
||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||
|
||||
|
||||
class SimpleTaskPool(BaseTaskPool):
|
||||
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None,
|
||||
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
|
||||
name: str = None) -> None:
|
||||
self._func: CoroutineFunc = func
|
||||
self._args: ArgsT = args
|
||||
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
|
||||
self._final_callback: FinalCallbackT = final_callback
|
||||
self._end_callback: EndCallbackT = end_callback
|
||||
self._cancel_callback: CancelCallbackT = cancel_callback
|
||||
super().__init__(name=name)
|
||||
|
||||
@ -234,12 +418,12 @@ class SimpleTaskPool(BaseTaskPool):
|
||||
def size(self) -> int:
|
||||
return self.num_running
|
||||
|
||||
def _start_one(self) -> int:
|
||||
return self._start_task(self._func(*self._args, **self._kwargs),
|
||||
final_callback=self._final_callback, cancel_callback=self._cancel_callback)
|
||||
async def _start_one(self) -> int:
|
||||
return await self._start_task(self._func(*self._args, **self._kwargs),
|
||||
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
|
||||
|
||||
def start(self, num: int = 1) -> List[int]:
|
||||
return [self._start_one() for _ in range(num)]
|
||||
async def start(self, num: int = 1) -> List[int]:
|
||||
return [await self._start_one() for _ in range(num)]
|
||||
|
||||
def stop(self, num: int = 1) -> List[int]:
|
||||
num = min(num, self.size)
|
||||
|
@ -45,11 +45,11 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
|
||||
self._server_kwargs = server_kwargs
|
||||
self._server: Optional[AbstractServer] = None
|
||||
|
||||
def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
||||
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
||||
if num is None:
|
||||
num = 1
|
||||
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
|
||||
writer.write(str(self._pool.start(num)).encode())
|
||||
writer.write(str(await self._pool.start(num)).encode())
|
||||
|
||||
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
||||
if num is None:
|
||||
@ -78,7 +78,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
|
||||
break
|
||||
cmd, arg = get_cmd_arg(msg)
|
||||
if cmd == constants.CMD_START:
|
||||
self._start_tasks(writer, arg)
|
||||
await self._start_tasks(writer, arg)
|
||||
elif cmd == constants.CMD_STOP:
|
||||
self._stop_tasks(writer, arg)
|
||||
elif cmd == constants.CMD_STOP_ALL:
|
||||
|
@ -5,7 +5,7 @@ from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
|
||||
ArgsT = Iterable[Any]
|
||||
KwArgsT = Mapping[str, Any]
|
||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||
FinalCallbackT = Callable
|
||||
EndCallbackT = Callable
|
||||
CancelCallbackT = Callable
|
||||
|
||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||
|
@ -1,35 +1,52 @@
|
||||
import asyncio
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, PropertyMock, patch, call
|
||||
from asyncio.exceptions import CancelledError
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||
|
||||
from asyncio_taskpool import pool
|
||||
from asyncio_taskpool import pool, exceptions
|
||||
|
||||
|
||||
EMPTY_LIST, EMPTY_DICT = [], {}
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class BaseTaskPoolTestCase(TestCase):
|
||||
class TestException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = pool.log.level
|
||||
pool.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
pool.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self._pools = getattr(pool.BaseTaskPool, '_pools')
|
||||
|
||||
# These three methods are called during initialization, so we mock them by default during setup
|
||||
self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
|
||||
self._check_more_allowed_patcher = patch.object(pool.BaseTaskPool, '_check_more_allowed')
|
||||
self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
|
||||
self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
|
||||
self.mock__add_pool = self._add_pool_patcher.start()
|
||||
self.mock__check_more_allowed = self._check_more_allowed_patcher.start()
|
||||
self.mock_pool_size = self.pool_size_patcher.start()
|
||||
self.mock___str__ = self.__str___patcher.start()
|
||||
self.mock__add_pool.return_value = self.mock_idx = 123
|
||||
self.mock___str__.return_value = self.mock_str = 'foobar'
|
||||
|
||||
# Test pool parameters:
|
||||
self.mock_pool_params = {'max_running': 420, 'name': 'test123'}
|
||||
self.task_pool = pool.BaseTaskPool(**self.mock_pool_params)
|
||||
self.test_pool_size, self.test_pool_name = 420, 'test123'
|
||||
self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
setattr(pool.TaskPool, '_pools', self._pools)
|
||||
self._add_pool_patcher.stop()
|
||||
self._check_more_allowed_patcher.stop()
|
||||
self.pool_size_patcher.stop()
|
||||
self.__str___patcher.stop()
|
||||
|
||||
def test__add_pool(self):
|
||||
@ -40,105 +57,309 @@ class BaseTaskPoolTestCase(TestCase):
|
||||
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
|
||||
|
||||
def test_init(self):
|
||||
for key, value in self.mock_pool_params.items():
|
||||
self.assertEqual(value, getattr(self.task_pool, f'_{key}'))
|
||||
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
|
||||
self.assertTrue(self.task_pool._open)
|
||||
self.assertEqual(0, self.task_pool._counter)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
|
||||
self.assertEqual(0, self.task_pool._num_cancelled)
|
||||
self.assertEqual(0, self.task_pool._num_ended)
|
||||
self.assertEqual(self.mock_idx, self.task_pool._idx)
|
||||
self.assertEqual(self.test_pool_name, self.task_pool._name)
|
||||
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
|
||||
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
|
||||
self.assertIsInstance(self.task_pool._more_allowed_flag, asyncio.locks.Event)
|
||||
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
||||
self.mock__check_more_allowed.assert_called_once_with()
|
||||
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
|
||||
self.mock___str__.assert_called_once_with()
|
||||
|
||||
def test___str__(self):
|
||||
self.__str___patcher.stop()
|
||||
expected_str = f'{pool.BaseTaskPool.__name__}-{self.mock_pool_params["name"]}'
|
||||
expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
|
||||
self.assertEqual(expected_str, str(self.task_pool))
|
||||
setattr(self.task_pool, '_name', None)
|
||||
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
|
||||
self.assertEqual(expected_str, str(self.task_pool))
|
||||
|
||||
def test_max_running(self):
|
||||
self.task_pool._max_running = foo = 'foo'
|
||||
self.assertEqual(foo, self.task_pool.max_running)
|
||||
def test_pool_size(self):
|
||||
self.pool_size_patcher.stop()
|
||||
self.task_pool._pool_size = self.test_pool_size
|
||||
self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.task_pool.pool_size = -1
|
||||
|
||||
self.task_pool.pool_size = new_size = 69
|
||||
self.assertEqual(new_size, self.task_pool._pool_size)
|
||||
|
||||
def test_is_open(self):
|
||||
self.task_pool._open = foo = 'foo'
|
||||
self.assertEqual(foo, self.task_pool.is_open)
|
||||
self.task_pool._open = FOO
|
||||
self.assertEqual(FOO, self.task_pool.is_open)
|
||||
|
||||
def test_num_running(self):
|
||||
self.task_pool._running = ['foo', 'bar', 'baz']
|
||||
self.assertEqual(3, self.task_pool.num_running)
|
||||
|
||||
def test_num_cancelled(self):
|
||||
self.task_pool._cancelled = ['foo', 'bar', 'baz']
|
||||
self.task_pool._num_cancelled = 3
|
||||
self.assertEqual(3, self.task_pool.num_cancelled)
|
||||
|
||||
def test_num_ended(self):
|
||||
self.task_pool._ended = ['foo', 'bar', 'baz']
|
||||
self.task_pool._num_ended = 3
|
||||
self.assertEqual(3, self.task_pool.num_ended)
|
||||
|
||||
@patch.object(pool.BaseTaskPool, 'num_ended', new_callable=PropertyMock)
|
||||
@patch.object(pool.BaseTaskPool, 'num_cancelled', new_callable=PropertyMock)
|
||||
def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock):
|
||||
mock_num_cancelled.return_value = cancelled = 69
|
||||
mock_num_ended.return_value = ended = 420
|
||||
self.assertEqual(ended - cancelled, self.task_pool.num_finished)
|
||||
def test_num_finished(self):
|
||||
self.task_pool._num_cancelled = cancelled = 69
|
||||
self.task_pool._num_ended = ended = 420
|
||||
self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
|
||||
self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)
|
||||
|
||||
def test_is_full(self):
|
||||
self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full)
|
||||
|
||||
@patch.object(pool.BaseTaskPool, 'max_running', new_callable=PropertyMock)
|
||||
@patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock)
|
||||
@patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock)
|
||||
def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock,
|
||||
mock_max_running: MagicMock):
|
||||
def reset_mocks():
|
||||
mock_is_full.reset_mock()
|
||||
mock_num_running.reset_mock()
|
||||
mock_max_running.reset_mock()
|
||||
self._check_more_allowed_patcher.stop()
|
||||
|
||||
# Just reaching limit, we expect flag to become unset:
|
||||
mock_is_full.return_value = False
|
||||
mock_num_running.return_value = mock_max_running.return_value = 420
|
||||
self.task_pool._more_allowed_flag.clear()
|
||||
self.task_pool._check_more_allowed()
|
||||
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
|
||||
mock_is_full.assert_has_calls([call(), call()])
|
||||
mock_num_running.assert_called_once_with()
|
||||
mock_max_running.assert_called_once_with()
|
||||
reset_mocks()
|
||||
|
||||
# Already at limit, we expect nothing to change:
|
||||
mock_is_full.return_value = True
|
||||
self.task_pool._check_more_allowed()
|
||||
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
|
||||
mock_is_full.assert_has_calls([call(), call()])
|
||||
mock_num_running.assert_called_once_with()
|
||||
mock_max_running.assert_called_once_with()
|
||||
reset_mocks()
|
||||
|
||||
# Just finished a task, we expect flag to become set:
|
||||
mock_num_running.return_value = 419
|
||||
self.task_pool._check_more_allowed()
|
||||
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
|
||||
mock_is_full.assert_called_once_with()
|
||||
mock_num_running.assert_called_once_with()
|
||||
mock_max_running.assert_called_once_with()
|
||||
reset_mocks()
|
||||
|
||||
# In this state we expect the flag to remain unchanged change:
|
||||
mock_is_full.return_value = False
|
||||
self.task_pool._check_more_allowed()
|
||||
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
|
||||
mock_is_full.assert_has_calls([call(), call()])
|
||||
mock_num_running.assert_called_once_with()
|
||||
mock_max_running.assert_called_once_with()
|
||||
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
|
||||
|
||||
def test__task_name(self):
|
||||
i = 123
|
||||
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
|
||||
|
||||
@patch.object(pool, '_execute_function')
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
|
||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||
self.task_pool._num_cancelled = cancelled = 3
|
||||
self.task_pool._running[task_id] = mock_task
|
||||
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._running)
|
||||
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
|
||||
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
|
||||
@patch.object(pool, '_execute_function')
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
|
||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||
self.task_pool._num_ended = ended = 3
|
||||
self.task_pool._enough_room._value = room = 123
|
||||
|
||||
# End running task:
|
||||
self.task_pool._running[task_id] = mock_task
|
||||
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._running)
|
||||
self.assertEqual(mock_task, self.task_pool._ended[task_id])
|
||||
self.assertEqual(ended + 1, self.task_pool._num_ended)
|
||||
self.assertEqual(room + 1, self.task_pool._enough_room._value)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
mock__task_name.reset_mock()
|
||||
mock__execute_function.reset_mock()
|
||||
|
||||
# End cancelled task:
|
||||
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
|
||||
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._cancelled)
|
||||
self.assertEqual(mock_task, self.task_pool._ended[task_id])
|
||||
self.assertEqual(ended + 2, self.task_pool._num_ended)
|
||||
self.assertEqual(room + 2, self.task_pool._enough_room._value)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
|
||||
@patch.object(pool.BaseTaskPool, '_task_ending')
|
||||
@patch.object(pool.BaseTaskPool, '_task_cancellation')
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_wrapper(self, mock__task_name: MagicMock,
|
||||
mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock):
|
||||
task_id = 42
|
||||
mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock()
|
||||
mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError)
|
||||
|
||||
# Cancelled during execution:
|
||||
mock_awaitable = mock_coroutine_func()
|
||||
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertIsNone(output)
|
||||
mock_coroutine_func.assert_awaited_once()
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb)
|
||||
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
|
||||
|
||||
mock_coroutine_func.reset_mock(side_effect=True)
|
||||
mock__task_name.reset_mock()
|
||||
mock__task_cancellation.reset_mock()
|
||||
mock__task_ending.reset_mock()
|
||||
|
||||
# Not cancelled:
|
||||
mock_awaitable = mock_coroutine_func()
|
||||
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(FOO, output)
|
||||
mock_coroutine_func.assert_awaited_once()
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__task_cancellation.assert_not_awaited()
|
||||
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
|
||||
|
||||
@patch.object(pool, 'create_task')
|
||||
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
|
||||
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
|
||||
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
|
||||
def reset_mocks() -> None:
|
||||
mock_is_open.reset_mock()
|
||||
mock__task_name.reset_mock()
|
||||
mock__task_wrapper.reset_mock()
|
||||
mock_create_task.reset_mock()
|
||||
|
||||
mock_create_task.return_value = mock_task = MagicMock()
|
||||
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
|
||||
mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
|
||||
self.task_pool._counter = count = 123
|
||||
self.task_pool._enough_room._value = room = 123
|
||||
|
||||
mock_is_open.return_value = ignore_closed = False
|
||||
with self.assertRaises(exceptions.PoolIsClosed):
|
||||
await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(count, self.task_pool._counter)
|
||||
self.assertNotIn(count, self.task_pool._running)
|
||||
self.assertEqual(room, self.task_pool._enough_room._value)
|
||||
mock_is_open.assert_called_once_with()
|
||||
mock__task_name.assert_not_called()
|
||||
mock__task_wrapper.assert_not_called()
|
||||
mock_create_task.assert_not_called()
|
||||
reset_mocks()
|
||||
|
||||
ignore_closed = True
|
||||
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(count, output)
|
||||
self.assertEqual(count + 1, self.task_pool._counter)
|
||||
self.assertEqual(mock_task, self.task_pool._running[count])
|
||||
self.assertEqual(room - 1, self.task_pool._enough_room._value)
|
||||
mock_is_open.assert_called_once_with()
|
||||
mock__task_name.assert_called_once_with(count)
|
||||
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
|
||||
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
|
||||
reset_mocks()
|
||||
self.task_pool._counter = count
|
||||
self.task_pool._enough_room._value = room
|
||||
del self.task_pool._running[count]
|
||||
|
||||
mock_create_task.side_effect = test_exception = TestException()
|
||||
with self.assertRaises(TestException) as e:
|
||||
await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(test_exception, e)
|
||||
self.assertEqual(count + 1, self.task_pool._counter)
|
||||
self.assertNotIn(count, self.task_pool._running)
|
||||
self.assertEqual(room, self.task_pool._enough_room._value)
|
||||
mock_is_open.assert_called_once_with()
|
||||
mock__task_name.assert_called_once_with(count)
|
||||
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
|
||||
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
|
||||
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
def test__get_running_task(self, mock__task_name: MagicMock):
|
||||
task_id, mock_task = 555, MagicMock()
|
||||
self.task_pool._running[task_id] = mock_task
|
||||
output = self.task_pool._get_running_task(task_id)
|
||||
self.assertEqual(mock_task, output)
|
||||
|
||||
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
|
||||
with self.assertRaises(exceptions.AlreadyCancelled):
|
||||
self.task_pool._get_running_task(task_id)
|
||||
mock__task_name.assert_called_once_with(task_id)
|
||||
mock__task_name.reset_mock()
|
||||
|
||||
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
|
||||
with self.assertRaises(exceptions.TaskEnded):
|
||||
self.task_pool._get_running_task(task_id)
|
||||
mock__task_name.assert_called_once_with(task_id)
|
||||
mock__task_name.reset_mock()
|
||||
|
||||
del self.task_pool._ended[task_id]
|
||||
with self.assertRaises(exceptions.InvalidTaskID):
|
||||
self.task_pool._get_running_task(task_id)
|
||||
mock__task_name.assert_not_called()
|
||||
|
||||
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
||||
def test_cancel(self, mock__get_running_task: MagicMock):
|
||||
task_id1, task_id2, task_id3 = 1, 4, 9
|
||||
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
||||
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
||||
|
||||
def test_cancel_all(self):
|
||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||
self.task_pool._running = {1: mock_task1, 2: mock_task2}
|
||||
assert not self.task_pool._interrupt_flag.is_set()
|
||||
self.assertIsNone(self.task_pool.cancel_all(FOO))
|
||||
self.assertTrue(self.task_pool._interrupt_flag.is_set())
|
||||
mock_task1.cancel.assert_called_once_with(msg=FOO)
|
||||
mock_task2.cancel.assert_called_once_with(msg=FOO)
|
||||
|
||||
async def test_flush(self):
|
||||
test_exception = TestException()
|
||||
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
|
||||
self.task_pool._ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._interrupt_flag.set()
|
||||
output = await self.task_pool.flush(return_exceptions=True)
|
||||
self.assertListEqual([FOO, test_exception], output)
|
||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
|
||||
self.task_pool._ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
||||
output = await self.task_pool.flush(return_exceptions=True)
|
||||
self.assertListEqual([FOO, test_exception], output)
|
||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
||||
|
||||
def test_close(self):
|
||||
assert self.task_pool._open
|
||||
self.task_pool.close()
|
||||
self.assertFalse(self.task_pool._open)
|
||||
|
||||
async def test_gather(self):
|
||||
mock_wait = AsyncMock()
|
||||
self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
|
||||
test_exception = TestException()
|
||||
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
|
||||
mock_running_func = AsyncMock(return_value=BAR)
|
||||
self.task_pool._ended = ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._running = running = {789: mock_running_func()}
|
||||
self.task_pool._interrupt_flag.set()
|
||||
|
||||
assert self.task_pool._open
|
||||
with self.assertRaises(exceptions.PoolStillOpen):
|
||||
await self.task_pool.gather()
|
||||
self.assertDictEqual(self.task_pool._ended, ended)
|
||||
self.assertDictEqual(self.task_pool._cancelled, cancelled)
|
||||
self.assertDictEqual(self.task_pool._running, running)
|
||||
self.assertTrue(self.task_pool._interrupt_flag.is_set())
|
||||
mock_wait.assert_not_awaited()
|
||||
|
||||
self.task_pool._open = False
|
||||
|
||||
def check_assertions() -> None:
|
||||
self.assertListEqual([FOO, test_exception, BAR], output)
|
||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
mock_wait.assert_awaited_once_with()
|
||||
|
||||
output = await self.task_pool.gather(return_exceptions=True)
|
||||
check_assertions()
|
||||
mock_wait.reset_mock()
|
||||
|
||||
self.task_pool._ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._running = {789: mock_running_func()}
|
||||
output = await self.task_pool.gather(return_exceptions=True)
|
||||
check_assertions()
|
||||
|
@ -33,9 +33,9 @@ async def work(n: int) -> None:
|
||||
|
||||
async def main() -> None:
|
||||
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
|
||||
pool.start(3) # launches work tasks 0, 1, and 2
|
||||
await pool.start(3) # launches work tasks 0, 1, and 2
|
||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||
pool.start() # launches work task 3
|
||||
await pool.start() # launches work task 3
|
||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||
pool.stop(2) # cancels tasks 3 and 2
|
||||
pool.close() # required for the last line
|
||||
|
@ -44,7 +44,7 @@ async def main() -> None:
|
||||
for item in range(100):
|
||||
q.put_nowait(item)
|
||||
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
|
||||
pool.start(3) # launches three worker tasks
|
||||
await pool.start(3) # launches three worker tasks
|
||||
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
|
||||
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
||||
await q.join()
|
||||
|
Reference in New Issue
Block a user