generated from daniil-berg/boilerplate-py
Compare commits
5 Commits
b5eed608b5
...
ac903d9be7
Author | SHA1 | Date | |
---|---|---|---|
ac903d9be7 | |||
e8e13406ea | |||
2d40f5707b | |||
c0c9246b87 | |||
ba0d5fca85 |
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,3 +8,5 @@
|
||||
/dist/
|
||||
# Python cache:
|
||||
__pycache__/
|
||||
# Testing:
|
||||
.coverage
|
||||
|
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = asyncio-taskpool
|
||||
version = 0.0.1
|
||||
version = 0.0.2
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Dynamically manage pools of asyncio tasks
|
||||
|
@ -14,7 +14,7 @@ class AlreadyCancelled(TaskEnded):
|
||||
pass
|
||||
|
||||
|
||||
class AlreadyFinished(TaskEnded):
|
||||
class AlreadyEnded(TaskEnded):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -39,6 +39,7 @@ class BaseTaskPool:
|
||||
self._name: str = name
|
||||
self._all_tasks_known_flag: Event = Event()
|
||||
self._all_tasks_known_flag.set()
|
||||
self._interrupt_flag: Event = Event()
|
||||
log.debug("%s initialized", str(self))
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -46,10 +47,22 @@ class BaseTaskPool:
|
||||
|
||||
@property
|
||||
def pool_size(self) -> int:
|
||||
"""Returns the maximum number of concurrently running tasks currently set in the pool."""
|
||||
return self._pool_size
|
||||
|
||||
@pool_size.setter
|
||||
def pool_size(self, value: int) -> None:
|
||||
"""
|
||||
Sets the maximum number of concurrently running tasks in the pool.
|
||||
|
||||
Args:
|
||||
value:
|
||||
A non-negative integer.
|
||||
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
|
||||
|
||||
Raises:
|
||||
`ValueError` if `value` is less than 0.
|
||||
"""
|
||||
if value < 0:
|
||||
raise ValueError("Pool size can not be less than 0")
|
||||
self._enough_room._value = value
|
||||
@ -106,20 +119,41 @@ class BaseTaskPool:
|
||||
"""Returns a standardized name for a task with a specific `task_id`."""
|
||||
return f'{self}_Task-{task_id}'
|
||||
|
||||
async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
|
||||
async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
|
||||
"""
|
||||
Universal callback to be run upon any task in the pool being cancelled.
|
||||
Required for keeping track of running/cancelled tasks and proper logging.
|
||||
|
||||
Args:
|
||||
task_id:
|
||||
The ID of the task that has been cancelled.
|
||||
custom_callback (optional):
|
||||
A callback to execute after cancellation of the task.
|
||||
It is run at the end of this function with the `task_id` as its only positional argument.
|
||||
"""
|
||||
log.debug("Cancelling %s ...", self._task_name(task_id))
|
||||
task = self._running.pop(task_id)
|
||||
assert task is not None
|
||||
self._cancelled[task_id] = task
|
||||
self._cancelled[task_id] = self._running.pop(task_id)
|
||||
self._num_cancelled += 1
|
||||
log.debug("Cancelled %s", self._task_name(task_id))
|
||||
await _execute_function(custom_callback, args=(task_id, ))
|
||||
|
||||
async def _end_task(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
|
||||
task = self._running.pop(task_id, None)
|
||||
if task is None:
|
||||
task = self._cancelled.pop(task_id)
|
||||
self._ended[task_id] = task
|
||||
async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
|
||||
"""
|
||||
Universal callback to be run upon any task in the pool ending its work.
|
||||
Required for keeping track of running/cancelled/ended tasks and proper logging.
|
||||
Also releases room in the task pool for potentially waiting tasks.
|
||||
|
||||
Args:
|
||||
task_id:
|
||||
The ID of the task that has reached its end.
|
||||
custom_callback (optional):
|
||||
A callback to execute after the task has ended.
|
||||
It is run at the end of this function with the `task_id` as its only positional argument.
|
||||
"""
|
||||
try:
|
||||
self._ended[task_id] = self._running.pop(task_id)
|
||||
except KeyError:
|
||||
self._ended[task_id] = self._cancelled.pop(task_id)
|
||||
self._num_ended += 1
|
||||
self._enough_room.release()
|
||||
log.info("Ended %s", self._task_name(task_id))
|
||||
@ -127,22 +161,58 @@ class BaseTaskPool:
|
||||
|
||||
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
|
||||
cancel_callback: CancelCallbackT = None) -> Any:
|
||||
"""
|
||||
Universal wrapper around every task to be run in the pool.
|
||||
Returns/raises whatever the wrapped coroutine does.
|
||||
|
||||
Args:
|
||||
awaitable:
|
||||
The actual coroutine to be run within the task pool.
|
||||
task_id:
|
||||
The ID of the newly created task.
|
||||
end_callback (optional):
|
||||
A callback to execute after the task has ended.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
cancel_callback (optional):
|
||||
A callback to execute after cancellation of the task.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
"""
|
||||
log.info("Started %s", self._task_name(task_id))
|
||||
try:
|
||||
return await awaitable
|
||||
except CancelledError:
|
||||
await self._cancel_task(task_id, custom_callback=cancel_callback)
|
||||
await self._task_cancellation(task_id, custom_callback=cancel_callback)
|
||||
finally:
|
||||
await self._end_task(task_id, custom_callback=end_callback)
|
||||
await self._task_ending(task_id, custom_callback=end_callback)
|
||||
|
||||
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
|
||||
cancel_callback: CancelCallbackT = None) -> int:
|
||||
"""
|
||||
Starts a coroutine as a new task in the pool.
|
||||
This method blocks, **only if** the pool is full.
|
||||
Returns/raises whatever the wrapped coroutine does.
|
||||
|
||||
Args:
|
||||
awaitable:
|
||||
The actual coroutine to be run within the task pool.
|
||||
ignore_closed (optional):
|
||||
If `True`, even if the pool is closed, the task will still be started.
|
||||
end_callback (optional):
|
||||
A callback to execute after the task has ended.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
cancel_callback (optional):
|
||||
A callback to execute after cancellation of the task.
|
||||
It is run with the `task_id` as its only positional argument.
|
||||
|
||||
Raises:
|
||||
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
|
||||
"""
|
||||
if not (self.is_open or ignore_closed):
|
||||
raise exceptions.PoolIsClosed("Cannot start new tasks")
|
||||
await self._enough_room.acquire()
|
||||
task_id = self._counter
|
||||
self._counter += 1
|
||||
try:
|
||||
task_id = self._counter
|
||||
self._counter += 1
|
||||
self._running[task_id] = create_task(
|
||||
self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
|
||||
name=self._task_name(task_id)
|
||||
@ -152,41 +222,115 @@ class BaseTaskPool:
|
||||
raise e
|
||||
return task_id
|
||||
|
||||
def _cancel_one(self, task_id: int, msg: str = None) -> None:
|
||||
def _get_running_task(self, task_id: int) -> Task:
|
||||
"""
|
||||
Gets a running task by its task ID.
|
||||
|
||||
Args:
|
||||
task_id: The ID of a task still running within the pool.
|
||||
|
||||
Raises:
|
||||
`asyncio_taskpool.exceptions.AlreadyCancelled` if the task with `task_id` has been (recently) cancelled.
|
||||
`asyncio_taskpool.exceptions.AlreadyEnded` if the task with `task_id` has ended (recently).
|
||||
`asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool.
|
||||
"""
|
||||
try:
|
||||
task = self._running[task_id]
|
||||
return self._running[task_id]
|
||||
except KeyError:
|
||||
if self._cancelled.get(task_id):
|
||||
raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
|
||||
if self._ended.get(task_id):
|
||||
raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running")
|
||||
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
||||
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
||||
task.cancel(msg=msg)
|
||||
|
||||
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
||||
for task_id in task_ids:
|
||||
self._cancel_one(task_id, msg=msg)
|
||||
"""
|
||||
Cancels the tasks with the specified IDs.
|
||||
|
||||
Each task ID must belong to a task still running within the pool. Otherwise one of the following exceptions will
|
||||
be raised:
|
||||
- `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled.
|
||||
- `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently).
|
||||
- `InvalidTaskID` if any of the `task_ids` is not known to the pool.
|
||||
Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten.
|
||||
|
||||
Args:
|
||||
task_ids:
|
||||
Arbitrary number of integers. Each must be an ID of a task still running within the pool.
|
||||
msg (optional):
|
||||
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||
"""
|
||||
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
||||
for task in tasks:
|
||||
task.cancel(msg=msg)
|
||||
|
||||
def cancel_all(self, msg: str = None) -> None:
|
||||
"""
|
||||
Cancels all tasks still running within the pool.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, those already running will be cancelled, while the following will **never even start**.
|
||||
|
||||
Args:
|
||||
msg (optional):
|
||||
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||
"""
|
||||
log.warning("%s cancelling all tasks!", str(self))
|
||||
self._interrupt_flag.set()
|
||||
for task in self._running.values():
|
||||
task.cancel(msg=msg)
|
||||
|
||||
async def flush(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
|
||||
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||
callbacks registered for the tasks block.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
"""
|
||||
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
|
||||
self._ended = self._cancelled = {}
|
||||
if self._interrupt_flag.is_set():
|
||||
self._interrupt_flag.clear()
|
||||
return results
|
||||
|
||||
def close(self) -> None:
|
||||
"""Disallows any more tasks to be started in the pool."""
|
||||
self._open = False
|
||||
log.info("%s is closed!", str(self))
|
||||
|
||||
async def gather(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
|
||||
|
||||
The `close()` method must have been called prior to this.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
|
||||
blocking this method. Otherwise it will wait until they all have started.
|
||||
|
||||
This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
|
||||
callbacks registered for a task blocks.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
|
||||
Raises:
|
||||
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
|
||||
"""
|
||||
if self._open:
|
||||
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
|
||||
await self._all_tasks_known_flag.wait()
|
||||
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._ended = self._cancelled = self._running = {}
|
||||
if self._interrupt_flag.is_set():
|
||||
self._interrupt_flag.clear()
|
||||
return results
|
||||
|
||||
|
||||
@ -224,7 +368,7 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
async def _start_next_coroutine() -> bool:
|
||||
cor = self._get_next_coroutine(func, args_iter, arg_stars)
|
||||
if cor is None:
|
||||
if cor is None or self._interrupt_flag.is_set():
|
||||
self._all_tasks_known_flag.set()
|
||||
return True
|
||||
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
|
||||
|
@ -1,14 +1,31 @@
|
||||
import asyncio
|
||||
from unittest import TestCase
|
||||
from unittest.mock import PropertyMock, patch
|
||||
from asyncio.exceptions import CancelledError
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||
|
||||
from asyncio_taskpool import pool
|
||||
from asyncio_taskpool import pool, exceptions
|
||||
|
||||
|
||||
EMPTY_LIST, EMPTY_DICT = [], {}
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class BaseTaskPoolTestCase(TestCase):
|
||||
class TestException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.log_lvl = pool.log.level
|
||||
pool.log.setLevel(999)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
pool.log.setLevel(cls.log_lvl)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self._pools = getattr(pool.BaseTaskPool, '_pools')
|
||||
|
||||
@ -52,6 +69,8 @@ class BaseTaskPoolTestCase(TestCase):
|
||||
self.assertEqual(self.test_pool_name, self.task_pool._name)
|
||||
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
|
||||
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
|
||||
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
||||
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
|
||||
self.mock___str__.assert_called_once_with()
|
||||
@ -76,15 +95,15 @@ class BaseTaskPoolTestCase(TestCase):
|
||||
self.assertEqual(new_size, self.task_pool._pool_size)
|
||||
|
||||
def test_is_open(self):
|
||||
self.task_pool._open = foo = 'foo'
|
||||
self.assertEqual(foo, self.task_pool.is_open)
|
||||
self.task_pool._open = FOO
|
||||
self.assertEqual(FOO, self.task_pool.is_open)
|
||||
|
||||
def test_num_running(self):
|
||||
self.task_pool._running = ['foo', 'bar', 'baz']
|
||||
self.assertEqual(3, self.task_pool.num_running)
|
||||
|
||||
def test_num_cancelled(self):
|
||||
self.task_pool._num_cancelled = 33
|
||||
self.task_pool._num_cancelled = 3
|
||||
self.assertEqual(3, self.task_pool.num_cancelled)
|
||||
|
||||
def test_num_ended(self):
|
||||
@ -103,3 +122,244 @@ class BaseTaskPoolTestCase(TestCase):
|
||||
def test__task_name(self):
|
||||
i = 123
|
||||
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
|
||||
|
||||
@patch.object(pool, '_execute_function')
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
|
||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||
self.task_pool._num_cancelled = cancelled = 3
|
||||
self.task_pool._running[task_id] = mock_task
|
||||
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._running)
|
||||
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
|
||||
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
|
||||
@patch.object(pool, '_execute_function')
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
|
||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||
self.task_pool._num_ended = ended = 3
|
||||
self.task_pool._enough_room._value = room = 123
|
||||
|
||||
# End running task:
|
||||
self.task_pool._running[task_id] = mock_task
|
||||
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._running)
|
||||
self.assertEqual(mock_task, self.task_pool._ended[task_id])
|
||||
self.assertEqual(ended + 1, self.task_pool._num_ended)
|
||||
self.assertEqual(room + 1, self.task_pool._enough_room._value)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
mock__task_name.reset_mock()
|
||||
mock__execute_function.reset_mock()
|
||||
|
||||
# End cancelled task:
|
||||
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
|
||||
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
||||
self.assertNotIn(task_id, self.task_pool._cancelled)
|
||||
self.assertEqual(mock_task, self.task_pool._ended[task_id])
|
||||
self.assertEqual(ended + 2, self.task_pool._num_ended)
|
||||
self.assertEqual(room + 2, self.task_pool._enough_room._value)
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||
|
||||
@patch.object(pool.BaseTaskPool, '_task_ending')
|
||||
@patch.object(pool.BaseTaskPool, '_task_cancellation')
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
async def test__task_wrapper(self, mock__task_name: MagicMock,
|
||||
mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock):
|
||||
task_id = 42
|
||||
mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock()
|
||||
mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError)
|
||||
|
||||
# Cancelled during execution:
|
||||
mock_awaitable = mock_coroutine_func()
|
||||
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertIsNone(output)
|
||||
mock_coroutine_func.assert_awaited_once()
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb)
|
||||
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
|
||||
|
||||
mock_coroutine_func.reset_mock(side_effect=True)
|
||||
mock__task_name.reset_mock()
|
||||
mock__task_cancellation.reset_mock()
|
||||
mock__task_ending.reset_mock()
|
||||
|
||||
# Not cancelled:
|
||||
mock_awaitable = mock_coroutine_func()
|
||||
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(FOO, output)
|
||||
mock_coroutine_func.assert_awaited_once()
|
||||
mock__task_name.assert_called_with(task_id)
|
||||
mock__task_cancellation.assert_not_awaited()
|
||||
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
|
||||
|
||||
@patch.object(pool, 'create_task')
|
||||
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
|
||||
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
|
||||
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
|
||||
def reset_mocks() -> None:
|
||||
mock_is_open.reset_mock()
|
||||
mock__task_name.reset_mock()
|
||||
mock__task_wrapper.reset_mock()
|
||||
mock_create_task.reset_mock()
|
||||
|
||||
mock_create_task.return_value = mock_task = MagicMock()
|
||||
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
|
||||
mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
|
||||
self.task_pool._counter = count = 123
|
||||
self.task_pool._enough_room._value = room = 123
|
||||
|
||||
mock_is_open.return_value = ignore_closed = False
|
||||
with self.assertRaises(exceptions.PoolIsClosed):
|
||||
await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(count, self.task_pool._counter)
|
||||
self.assertNotIn(count, self.task_pool._running)
|
||||
self.assertEqual(room, self.task_pool._enough_room._value)
|
||||
mock_is_open.assert_called_once_with()
|
||||
mock__task_name.assert_not_called()
|
||||
mock__task_wrapper.assert_not_called()
|
||||
mock_create_task.assert_not_called()
|
||||
reset_mocks()
|
||||
|
||||
ignore_closed = True
|
||||
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(count, output)
|
||||
self.assertEqual(count + 1, self.task_pool._counter)
|
||||
self.assertEqual(mock_task, self.task_pool._running[count])
|
||||
self.assertEqual(room - 1, self.task_pool._enough_room._value)
|
||||
mock_is_open.assert_called_once_with()
|
||||
mock__task_name.assert_called_once_with(count)
|
||||
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
|
||||
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
|
||||
reset_mocks()
|
||||
self.task_pool._counter = count
|
||||
self.task_pool._enough_room._value = room
|
||||
del self.task_pool._running[count]
|
||||
|
||||
mock_create_task.side_effect = test_exception = TestException()
|
||||
with self.assertRaises(TestException) as e:
|
||||
await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||
self.assertEqual(test_exception, e)
|
||||
self.assertEqual(count + 1, self.task_pool._counter)
|
||||
self.assertNotIn(count, self.task_pool._running)
|
||||
self.assertEqual(room, self.task_pool._enough_room._value)
|
||||
mock_is_open.assert_called_once_with()
|
||||
mock__task_name.assert_called_once_with(count)
|
||||
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
|
||||
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
|
||||
|
||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||
def test__get_running_task(self, mock__task_name: MagicMock):
|
||||
task_id, mock_task = 555, MagicMock()
|
||||
self.task_pool._running[task_id] = mock_task
|
||||
output = self.task_pool._get_running_task(task_id)
|
||||
self.assertEqual(mock_task, output)
|
||||
|
||||
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
|
||||
with self.assertRaises(exceptions.AlreadyCancelled):
|
||||
self.task_pool._get_running_task(task_id)
|
||||
mock__task_name.assert_called_once_with(task_id)
|
||||
mock__task_name.reset_mock()
|
||||
|
||||
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
|
||||
with self.assertRaises(exceptions.TaskEnded):
|
||||
self.task_pool._get_running_task(task_id)
|
||||
mock__task_name.assert_called_once_with(task_id)
|
||||
mock__task_name.reset_mock()
|
||||
|
||||
del self.task_pool._ended[task_id]
|
||||
with self.assertRaises(exceptions.InvalidTaskID):
|
||||
self.task_pool._get_running_task(task_id)
|
||||
mock__task_name.assert_not_called()
|
||||
|
||||
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
||||
def test_cancel(self, mock__get_running_task: MagicMock):
|
||||
task_id1, task_id2, task_id3 = 1, 4, 9
|
||||
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
||||
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
||||
|
||||
def test_cancel_all(self):
|
||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||
self.task_pool._running = {1: mock_task1, 2: mock_task2}
|
||||
assert not self.task_pool._interrupt_flag.is_set()
|
||||
self.assertIsNone(self.task_pool.cancel_all(FOO))
|
||||
self.assertTrue(self.task_pool._interrupt_flag.is_set())
|
||||
mock_task1.cancel.assert_called_once_with(msg=FOO)
|
||||
mock_task2.cancel.assert_called_once_with(msg=FOO)
|
||||
|
||||
async def test_flush(self):
|
||||
test_exception = TestException()
|
||||
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
|
||||
self.task_pool._ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._interrupt_flag.set()
|
||||
output = await self.task_pool.flush(return_exceptions=True)
|
||||
self.assertListEqual([FOO, test_exception], output)
|
||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
|
||||
self.task_pool._ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
||||
output = await self.task_pool.flush(return_exceptions=True)
|
||||
self.assertListEqual([FOO, test_exception], output)
|
||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
||||
|
||||
def test_close(self):
|
||||
assert self.task_pool._open
|
||||
self.task_pool.close()
|
||||
self.assertFalse(self.task_pool._open)
|
||||
|
||||
async def test_gather(self):
|
||||
mock_wait = AsyncMock()
|
||||
self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
|
||||
test_exception = TestException()
|
||||
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
|
||||
mock_running_func = AsyncMock(return_value=BAR)
|
||||
self.task_pool._ended = ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._running = running = {789: mock_running_func()}
|
||||
self.task_pool._interrupt_flag.set()
|
||||
|
||||
assert self.task_pool._open
|
||||
with self.assertRaises(exceptions.PoolStillOpen):
|
||||
await self.task_pool.gather()
|
||||
self.assertDictEqual(self.task_pool._ended, ended)
|
||||
self.assertDictEqual(self.task_pool._cancelled, cancelled)
|
||||
self.assertDictEqual(self.task_pool._running, running)
|
||||
self.assertTrue(self.task_pool._interrupt_flag.is_set())
|
||||
mock_wait.assert_not_awaited()
|
||||
|
||||
self.task_pool._open = False
|
||||
|
||||
def check_assertions() -> None:
|
||||
self.assertListEqual([FOO, test_exception, BAR], output)
|
||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
||||
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
mock_wait.assert_awaited_once_with()
|
||||
|
||||
output = await self.task_pool.gather(return_exceptions=True)
|
||||
check_assertions()
|
||||
mock_wait.reset_mock()
|
||||
|
||||
self.task_pool._ended = {123: mock_ended_func()}
|
||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._running = {789: mock_running_func()}
|
||||
output = await self.task_pool.gather(return_exceptions=True)
|
||||
check_assertions()
|
||||
|
Loading…
Reference in New Issue
Block a user