new interrupt flag for cancel_all to prevent _map() from starting new tasks

This commit is contained in:
Daniil Fajnberg 2022-02-06 16:32:42 +01:00
parent 2d40f5707b
commit e8e13406ea
2 changed files with 25 additions and 16 deletions

View File

@ -39,6 +39,7 @@ class BaseTaskPool:
self._name: str = name
self._all_tasks_known_flag: Event = Event()
self._all_tasks_known_flag.set()
self._interrupt_flag: Event = Event()
log.debug("%s initialized", str(self))
def __str__(self) -> str:
@ -188,7 +189,7 @@ class BaseTaskPool:
cancel_callback: CancelCallbackT = None) -> int:
"""
Starts a coroutine as a new task in the pool.
This method blocks, **only if** there the pool is full.
This method blocks, **only if** the pool is full.
Returns/raises whatever the wrapped coroutine does.
Args:
@ -273,25 +274,27 @@ class BaseTaskPool:
for task in tasks:
task.cancel(msg=msg)
async def cancel_all(self, msg: str = None) -> None:
def cancel_all(self, msg: str = None) -> None:
"""
Cancels all tasks still running within the pool.
This method blocks, **only if** a currently unknown number of coroutine functions have been registered to be
run as tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a
number smaller than the number of arguments from `args_iter`.
TODO: Consider changing this behaviour.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
than the number of arguments from `args_iter`.
In this case, those already running will be cancelled, while the following will **never even start**.
Args:
msg (optional):
Passed to the `Task.cancel()` method of every task currently running in the pool.
"""
await self._all_tasks_known_flag.wait()
log.warning("%s cancelling all tasks!", str(self))
self._interrupt_flag.set()
for task in self._running.values():
task.cancel(msg=msg)
async def flush(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on all ended/cancelled tasks from the task pool, returns their results, and forgets them.
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
This method blocks, **only if** any of the tasks block while catching an `asyncio.CancelledError` or any of
the callbacks registered for the tasks block.
@ -300,6 +303,7 @@ class BaseTaskPool:
"""
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
self._ended = self._cancelled = {}
self._interrupt_flag.clear()
return results
def close(self) -> None:
@ -309,16 +313,18 @@ class BaseTaskPool:
async def gather(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on all tasks from the task pool, returns their results, and forgets them.
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
The `close()` method must have been called prior to this.
This method blocks, if a currently unknown number of coroutine functions have been registered to be run as
tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number
smaller than the number of arguments from `args_iter`.
TODO: Consider changing this behaviour.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
than the number of arguments from `args_iter`.
In this case, calling `cancel_all()` prior to this will prevent those tasks from starting and potentially
blocking this method. Otherwise, it will wait until they all have started.
This method also blocks, if any of the tasks block while catching a `asyncio.CancelledError` or any of the
callbacks registered for the tasks block.
This method may also block, if any task blocks while catching an `asyncio.CancelledError` or if any of
the callbacks registered for a task blocks.
Args:
return_exceptions (optional): Passed directly into `gather`.
@ -332,6 +338,7 @@ class BaseTaskPool:
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {}
self._interrupt_flag.clear()
return results
@ -369,7 +376,7 @@ class TaskPool(BaseTaskPool):
async def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None:
if cor is None or self._interrupt_flag.is_set():
self._all_tasks_known_flag.set()
return True
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)

View File

@ -52,6 +52,8 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
self.mock___str__.assert_called_once_with()