generated from daniil-berg/boilerplate-py
new interrupt flag for cancel_all
to prevent _map()
from starting new tasks
This commit is contained in:
parent
2d40f5707b
commit
e8e13406ea
@ -39,6 +39,7 @@ class BaseTaskPool:
|
||||
self._name: str = name
|
||||
self._all_tasks_known_flag: Event = Event()
|
||||
self._all_tasks_known_flag.set()
|
||||
self._interrupt_flag: Event = Event()
|
||||
log.debug("%s initialized", str(self))
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -188,7 +189,7 @@ class BaseTaskPool:
|
||||
cancel_callback: CancelCallbackT = None) -> int:
|
||||
"""
|
||||
Starts a coroutine as a new task in the pool.
|
||||
This method blocks, **only if** there the pool is full.
|
||||
This method blocks, **only if** the pool is full.
|
||||
Returns/raises whatever the wrapped coroutine does.
|
||||
|
||||
Args:
|
||||
@ -273,25 +274,27 @@ class BaseTaskPool:
|
||||
for task in tasks:
|
||||
task.cancel(msg=msg)
|
||||
|
||||
async def cancel_all(self, msg: str = None) -> None:
|
||||
def cancel_all(self, msg: str = None) -> None:
|
||||
"""
|
||||
Cancels all tasks still running within the pool.
|
||||
This method blocks, **only if** a currently unknown number of coroutine functions have been registered to be
|
||||
run as tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a
|
||||
number smaller than the number of arguments from `args_iter`.
|
||||
TODO: Consider changing this behaviour.
|
||||
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, those already running will be cancelled, while the following will **never even start**.
|
||||
|
||||
Args:
|
||||
msg (optional):
|
||||
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||
"""
|
||||
await self._all_tasks_known_flag.wait()
|
||||
log.warning("%s cancelling all tasks!", str(self))
|
||||
self._interrupt_flag.set()
|
||||
for task in self._running.values():
|
||||
task.cancel(msg=msg)
|
||||
|
||||
async def flush(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks from the task pool, returns their results, and forgets them.
|
||||
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
|
||||
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||
callbacks registered for the tasks block.
|
||||
|
||||
@ -300,6 +303,7 @@ class BaseTaskPool:
|
||||
"""
|
||||
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
|
||||
self._ended = self._cancelled = {}
|
||||
self._interrupt_flag.clear()
|
||||
return results
|
||||
|
||||
def close(self) -> None:
|
||||
@ -309,16 +313,18 @@ class BaseTaskPool:
|
||||
|
||||
async def gather(self, return_exceptions: bool = False):
|
||||
"""
|
||||
Calls `asyncio.gather` on all tasks from the task pool, returns their results, and forgets them.
|
||||
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
|
||||
|
||||
The `close()` method must have been called prior to this.
|
||||
|
||||
This method blocks, if a currently unknown number of coroutine functions have been registered to be run as
|
||||
tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number
|
||||
smaller than the number of arguments from `args_iter`.
|
||||
TODO: Consider changing this behaviour.
|
||||
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||
than the number of arguments from `args_iter`.
|
||||
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
|
||||
blocking this method. Otherwise it will wait until they all have started.
|
||||
|
||||
This method also blocks, if any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||
callbacks registered for the tasks block.
|
||||
This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
|
||||
callbacks registered for a task blocks.
|
||||
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
@ -332,6 +338,7 @@ class BaseTaskPool:
|
||||
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._ended = self._cancelled = self._running = {}
|
||||
self._interrupt_flag.clear()
|
||||
return results
|
||||
|
||||
|
||||
@ -369,7 +376,7 @@ class TaskPool(BaseTaskPool):
|
||||
|
||||
async def _start_next_coroutine() -> bool:
|
||||
cor = self._get_next_coroutine(func, args_iter, arg_stars)
|
||||
if cor is None:
|
||||
if cor is None or self._interrupt_flag.is_set():
|
||||
self._all_tasks_known_flag.set()
|
||||
return True
|
||||
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
|
||||
|
@ -52,6 +52,8 @@ class BaseTaskPoolTestCase(TestCase):
|
||||
self.assertEqual(self.test_pool_name, self.task_pool._name)
|
||||
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
|
||||
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
|
||||
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
|
||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
||||
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
|
||||
self.mock___str__.assert_called_once_with()
|
||||
|
Loading…
x
Reference in New Issue
Block a user