Compare commits

...

4 Commits

7 changed files with 512 additions and 60 deletions

2
.gitignore vendored
View File

@ -8,3 +8,5 @@
/dist/ /dist/
# Python cache: # Python cache:
__pycache__/ __pycache__/
# Testing:
.coverage

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.0.1 version = 0.0.3
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -14,7 +14,7 @@ class AlreadyCancelled(TaskEnded):
pass pass
class AlreadyFinished(TaskEnded): class AlreadyEnded(TaskEnded):
pass pass
@ -24,3 +24,7 @@ class InvalidTaskID(PoolException):
class PoolStillOpen(PoolException): class PoolStillOpen(PoolException):
pass pass
class NotCoroutine(PoolException):
pass

View File

@ -0,0 +1,24 @@
from asyncio.coroutines import iscoroutinefunction
from typing import Any, Optional
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
if not callable(function):
return
if kwargs is None:
kwargs = {}
if iscoroutinefunction(function):
return await function(*args, **kwargs)
return function(*args, **kwargs)
def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
if arg_stars == 0:
return function(arg)
if arg_stars == 1:
return function(*arg)
if arg_stars == 2:
return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")

View File

@ -1,13 +1,15 @@
import logging import logging
from asyncio import gather from asyncio import gather
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task from asyncio.tasks import Task, create_task
from functools import partial
from math import inf from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from typing import Any, Awaitable, Dict, Iterable, Iterator, List
from . import exceptions from . import exceptions
from .helpers import execute_optional, star_function
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
@ -39,6 +41,7 @@ class BaseTaskPool:
self._name: str = name self._name: str = name
self._all_tasks_known_flag: Event = Event() self._all_tasks_known_flag: Event = Event()
self._all_tasks_known_flag.set() self._all_tasks_known_flag.set()
self._interrupt_flag: Event = Event()
log.debug("%s initialized", str(self)) log.debug("%s initialized", str(self))
def __str__(self) -> str: def __str__(self) -> str:
@ -46,10 +49,22 @@ class BaseTaskPool:
@property @property
def pool_size(self) -> int: def pool_size(self) -> int:
"""Returns the maximum number of concurrently running tasks currently set in the pool."""
return self._pool_size return self._pool_size
@pool_size.setter @pool_size.setter
def pool_size(self, value: int) -> None: def pool_size(self, value: int) -> None:
"""
Sets the maximum number of concurrently running tasks in the pool.
Args:
value:
A non-negative integer.
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
Raises:
`ValueError` if `value` is less than 0.
"""
if value < 0: if value < 0:
raise ValueError("Pool size can not be less than 0") raise ValueError("Pool size can not be less than 0")
self._enough_room._value = value self._enough_room._value = value
@ -107,13 +122,36 @@ class BaseTaskPool:
return f'{self}_Task-{task_id}' return f'{self}_Task-{task_id}'
async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None: async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
"""
Universal callback to be run upon any task in the pool being cancelled.
Required for keeping track of running/cancelled tasks and proper logging.
Args:
task_id:
The ID of the task that has been cancelled.
custom_callback (optional):
A callback to execute after cancellation of the task.
It is run at the end of this function with the `task_id` as its only positional argument.
"""
log.debug("Cancelling %s ...", self._task_name(task_id)) log.debug("Cancelling %s ...", self._task_name(task_id))
self._cancelled[task_id] = self._running.pop(task_id) self._cancelled[task_id] = self._running.pop(task_id)
self._num_cancelled += 1 self._num_cancelled += 1
log.debug("Cancelled %s", self._task_name(task_id)) log.debug("Cancelled %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, )) await execute_optional(custom_callback, args=(task_id,))
async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None: async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
"""
Universal callback to be run upon any task in the pool ending its work.
Required for keeping track of running/cancelled/ended tasks and proper logging.
Also releases room in the task pool for potentially waiting tasks.
Args:
task_id:
The ID of the task that has reached its end.
custom_callback (optional):
A callback to execute after the task has ended.
It is run at the end of this function with the `task_id` as its only positional argument.
"""
try: try:
self._ended[task_id] = self._running.pop(task_id) self._ended[task_id] = self._running.pop(task_id)
except KeyError: except KeyError:
@ -121,10 +159,26 @@ class BaseTaskPool:
self._num_ended += 1 self._num_ended += 1
self._enough_room.release() self._enough_room.release()
log.info("Ended %s", self._task_name(task_id)) log.info("Ended %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, )) await execute_optional(custom_callback, args=(task_id,))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None, async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any: cancel_callback: CancelCallbackT = None) -> Any:
"""
Universal wrapper around every task to be run in the pool.
Returns/raises whatever the wrapped coroutine does.
Args:
awaitable:
The actual coroutine to be run within the task pool.
task_id:
The ID of the newly created task.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the `task_id` as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
It is run with the `task_id` as its only positional argument.
"""
log.info("Started %s", self._task_name(task_id)) log.info("Started %s", self._task_name(task_id))
try: try:
return await awaitable return await awaitable
@ -135,12 +189,34 @@ class BaseTaskPool:
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None, async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int: cancel_callback: CancelCallbackT = None) -> int:
"""
Starts a coroutine as a new task in the pool.
This method blocks, **only if** the pool is full.
Returns/raises whatever the wrapped coroutine does.
Args:
awaitable:
The actual coroutine to be run within the task pool.
ignore_closed (optional):
If `True`, even if the pool is closed, the task will still be started.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the `task_id` as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
It is run with the `task_id` as its only positional argument.
Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
"""
if not iscoroutine(awaitable):
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed): if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsClosed("Cannot start new tasks")
await self._enough_room.acquire() await self._enough_room.acquire()
task_id = self._counter
self._counter += 1
try: try:
task_id = self._counter
self._counter += 1
self._running[task_id] = create_task( self._running[task_id] = create_task(
self._task_wrapper(awaitable, task_id, end_callback, cancel_callback), self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
name=self._task_name(task_id) name=self._task_name(task_id)
@ -151,44 +227,114 @@ class BaseTaskPool:
return task_id return task_id
def _get_running_task(self, task_id: int) -> Task: def _get_running_task(self, task_id: int) -> Task:
"""
Gets a running task by its task ID.
Args:
task_id: The ID of a task still running within the pool.
Raises:
`asyncio_taskpool.exceptions.AlreadyCancelled` if the task with `task_id` has been (recently) cancelled.
`asyncio_taskpool.exceptions.AlreadyEnded` if the task with `task_id` has ended (recently).
`asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool.
"""
try: try:
return self._running[task_id] return self._running[task_id]
except KeyError: except KeyError:
if self._cancelled.get(task_id): if self._cancelled.get(task_id):
raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled") raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
if self._ended.get(task_id): if self._ended.get(task_id):
raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running") raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
def _cancel_task(self, task_id: int, msg: str = None) -> None:
self._get_running_task(task_id).cancel(msg=msg)
def cancel(self, *task_ids: int, msg: str = None) -> None: def cancel(self, *task_ids: int, msg: str = None) -> None:
"""
Cancels the tasks with the specified IDs.
Each task ID must belong to a task still running within the pool. Otherwise one of the following exceptions will
be raised:
- `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled.
- `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently).
- `InvalidTaskID` if any of the `task_ids` is not known to the pool.
Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten.
Args:
task_ids:
Arbitrary number of integers. Each must be an ID of a task still running within the pool.
msg (optional):
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
"""
tasks = [self._get_running_task(task_id) for task_id in task_ids] tasks = [self._get_running_task(task_id) for task_id in task_ids]
for task in tasks: for task in tasks:
task.cancel(msg=msg) task.cancel(msg=msg)
async def cancel_all(self, msg: str = None) -> None: def cancel_all(self, msg: str = None) -> None:
await self._all_tasks_known_flag.wait() """
Cancels all tasks still running within the pool.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
than the number of arguments from `args_iter`.
In this case, those already running will be cancelled, while the following will **never even start**.
Args:
msg (optional):
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
"""
log.warning("%s cancelling all tasks!", str(self))
self._interrupt_flag.set()
for task in self._running.values(): for task in self._running.values():
task.cancel(msg=msg) task.cancel(msg=msg)
async def flush(self, return_exceptions: bool = False): async def flush(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
callbacks registered for the tasks block.
Args:
return_exceptions (optional): Passed directly into `gather`.
"""
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions) results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
self._ended = self._cancelled = {} self._ended = self._cancelled = {}
if self._interrupt_flag.is_set():
self._interrupt_flag.clear()
return results return results
def close(self) -> None: def close(self) -> None:
"""Disallows any more tasks to be started in the pool."""
self._open = False self._open = False
log.info("%s is closed!", str(self)) log.info("%s is closed!", str(self))
async def gather(self, return_exceptions: bool = False): async def gather(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
The `close()` method must have been called prior to this.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
than the number of arguments from `args_iter`.
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
blocking this method. Otherwise it will wait until they all have started.
This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
callbacks registered for a task blocks.
Args:
return_exceptions (optional): Passed directly into `gather`.
Raises:
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
"""
if self._open: if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
await self._all_tasks_known_flag.wait() await self._all_tasks_known_flag.wait()
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions) return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {} self._ended = self._cancelled = self._running = {}
if self._interrupt_flag.is_set():
self._interrupt_flag.clear()
return results return results
@ -200,45 +346,47 @@ class TaskPool(BaseTaskPool):
return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback) return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]:
return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)) ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
assert isinstance(ids, list)
return ids
@staticmethod async def _next_callback(self, task_id: int, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
reached_end = await self._start_next_task(func, args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback)
if reached_end:
self._all_tasks_known_flag.set()
await execute_optional(end_callback, args=(task_id,))
async def _start_next_task(self, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> bool:
if self._interrupt_flag.is_set():
return True
try: try:
arg = next(args_iter) await self._start_task(
star_function(func, next(args_iter), arg_stars=arg_stars),
ignore_closed=True,
end_callback=partial(TaskPool._next_callback, self, func=func, args_iter=args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback),
cancel_callback=cancel_callback
)
except StopIteration: except StopIteration:
return return True
if arg_stars == 0: return False
return func(arg)
if arg_stars == 1:
return func(*arg)
if arg_stars == 2:
return func(**arg)
raise ValueError
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
if self._all_tasks_known_flag.is_set(): if self._all_tasks_known_flag.is_set():
self._all_tasks_known_flag.clear() self._all_tasks_known_flag.clear()
args_iter = iter(args_iter) args_iter = iter(args_iter)
async def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None:
self._all_tasks_known_flag.set()
return True
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
return False
async def _start_next(task_id: int) -> None:
await _start_next_coroutine()
await _execute_function(end_callback, args=(task_id, ))
for _ in range(num_tasks): for _ in range(num_tasks):
reached_end = await _start_next_coroutine() reached_end = await self._start_next_task(func, args_iter, arg_stars, end_callback, cancel_callback)
if reached_end: if reached_end:
self._all_tasks_known_flag.set()
break break
async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
@ -261,6 +409,8 @@ class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None: name: str = None) -> None:
if not iscoroutinefunction(func):
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
self._func: CoroutineFunc = func self._func: CoroutineFunc = func
self._args: ArgsT = args self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {} self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
@ -295,13 +445,3 @@ class SimpleTaskPool(BaseTaskPool):
def stop_all(self) -> List[int]: def stop_all(self) -> List[int]:
return self.stop(self.size) return self.stop(self.size)
async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None:
if kwargs is None:
kwargs = {}
if callable(func):
if iscoroutinefunction(func):
await func(*args, **kwargs)
else:
func(*args, **kwargs)

View File

@ -1,10 +1,15 @@
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T')
ArgsT = Iterable[Any] ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable EndCallbackT = Callable
CancelCallbackT = Callable CancelCallbackT = Callable

View File

@ -1,14 +1,31 @@
import asyncio import asyncio
from unittest import TestCase from asyncio.exceptions import CancelledError
from unittest.mock import PropertyMock, patch from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from asyncio_taskpool import pool from asyncio_taskpool import pool, exceptions
EMPTY_LIST, EMPTY_DICT = [], {} EMPTY_LIST, EMPTY_DICT = [], {}
FOO, BAR = 'foo', 'bar'
class BaseTaskPoolTestCase(TestCase): class TestException(Exception):
pass
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = pool.log.level
pool.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
def setUp(self) -> None: def setUp(self) -> None:
self._pools = getattr(pool.BaseTaskPool, '_pools') self._pools = getattr(pool.BaseTaskPool, '_pools')
@ -52,6 +69,8 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set()) self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.test_pool_size) self.mock_pool_size.assert_called_once_with(self.test_pool_size)
self.mock___str__.assert_called_once_with() self.mock___str__.assert_called_once_with()
@ -76,15 +95,15 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(new_size, self.task_pool._pool_size) self.assertEqual(new_size, self.task_pool._pool_size)
def test_is_open(self): def test_is_open(self):
self.task_pool._open = foo = 'foo' self.task_pool._open = FOO
self.assertEqual(foo, self.task_pool.is_open) self.assertEqual(FOO, self.task_pool.is_open)
def test_num_running(self): def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz'] self.task_pool._running = ['foo', 'bar', 'baz']
self.assertEqual(3, self.task_pool.num_running) self.assertEqual(3, self.task_pool.num_running)
def test_num_cancelled(self): def test_num_cancelled(self):
self.task_pool._num_cancelled = 33 self.task_pool._num_cancelled = 3
self.assertEqual(3, self.task_pool.num_cancelled) self.assertEqual(3, self.task_pool.num_cancelled)
def test_num_ended(self): def test_num_ended(self):
@ -103,3 +122,261 @@ class BaseTaskPoolTestCase(TestCase):
def test__task_name(self): def test__task_name(self):
i = 123 i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
# End running task:
self.task_pool._running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock__task_name.reset_mock()
mock_execute_optional.reset_mock()
# End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._cancelled)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool.BaseTaskPool, '_task_ending')
@patch.object(pool.BaseTaskPool, '_task_cancellation')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_wrapper(self, mock__task_name: MagicMock,
mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock):
task_id = 42
mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock()
mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError)
# Cancelled during execution:
mock_awaitable = mock_coroutine_func()
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertIsNone(output)
mock_coroutine_func.assert_awaited_once()
mock__task_name.assert_called_with(task_id)
mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb)
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
mock_coroutine_func.reset_mock(side_effect=True)
mock__task_name.reset_mock()
mock__task_cancellation.reset_mock()
mock__task_ending.reset_mock()
# Not cancelled:
mock_awaitable = mock_coroutine_func()
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(FOO, output)
mock_coroutine_func.assert_awaited_once()
mock__task_name.assert_called_with(task_id)
mock__task_cancellation.assert_not_awaited()
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
@patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
def reset_mocks() -> None:
mock_is_open.reset_mock()
mock__task_name.reset_mock()
mock__task_wrapper.reset_mock()
mock_create_task.reset_mock()
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_not_called()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
mock_is_open.return_value = ignore_closed = False
mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
ignore_closed = True
mock_awaitable = mock_coroutine()
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
self.assertEqual(room - 1, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
reset_mocks()
self.task_pool._counter = count
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
mock_awaitable = mock_coroutine()
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
await mock_awaitable
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
def test__get_running_task(self, mock__task_name: MagicMock):
task_id, mock_task = 555, MagicMock()
self.task_pool._running[task_id] = mock_task
output = self.task_pool._get_running_task(task_id)
self.assertEqual(mock_task, output)
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
with self.assertRaises(exceptions.AlreadyCancelled):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
with self.assertRaises(exceptions.TaskEnded):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
del self.task_pool._ended[task_id]
with self.assertRaises(exceptions.InvalidTaskID):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock):
task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
def test_cancel_all(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
self.task_pool._running = {1: mock_task1, 2: mock_task2}
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(self.task_pool.cancel_all(FOO))
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_task1.cancel.assert_called_once_with(msg=FOO)
mock_task2.cancel.assert_called_once_with(msg=FOO)
async def test_flush(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._interrupt_flag.set()
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
def test_close(self):
assert self.task_pool._open
self.task_pool.close()
self.assertFalse(self.task_pool._open)
async def test_gather(self):
mock_wait = AsyncMock()
self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR)
self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set()
assert self.task_pool._open
with self.assertRaises(exceptions.PoolStillOpen):
await self.task_pool.gather()
self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled)
self.assertDictEqual(self.task_pool._running, running)
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_not_awaited()
self.task_pool._open = False
def check_assertions() -> None:
self.assertListEqual([FOO, test_exception, BAR], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_awaited_once_with()
output = await self.task_pool.gather(return_exceptions=True)
check_assertions()
mock_wait.reset_mock()
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
output = await self.task_pool.gather(return_exceptions=True)
check_assertions()