generated from daniil-berg/boilerplate-py
new interrupt flag for cancel_all
to prevent _map()
from starting new tasks
This commit is contained in:
parent
2d40f5707b
commit
e8e13406ea
@ -39,6 +39,7 @@ class BaseTaskPool:
|
|||||||
self._name: str = name
|
self._name: str = name
|
||||||
self._all_tasks_known_flag: Event = Event()
|
self._all_tasks_known_flag: Event = Event()
|
||||||
self._all_tasks_known_flag.set()
|
self._all_tasks_known_flag.set()
|
||||||
|
self._interrupt_flag: Event = Event()
|
||||||
log.debug("%s initialized", str(self))
|
log.debug("%s initialized", str(self))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@ -188,7 +189,7 @@ class BaseTaskPool:
|
|||||||
cancel_callback: CancelCallbackT = None) -> int:
|
cancel_callback: CancelCallbackT = None) -> int:
|
||||||
"""
|
"""
|
||||||
Starts a coroutine as a new task in the pool.
|
Starts a coroutine as a new task in the pool.
|
||||||
This method blocks, **only if** there the pool is full.
|
This method blocks, **only if** the pool is full.
|
||||||
Returns/raises whatever the wrapped coroutine does.
|
Returns/raises whatever the wrapped coroutine does.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -273,25 +274,27 @@ class BaseTaskPool:
|
|||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel(msg=msg)
|
task.cancel(msg=msg)
|
||||||
|
|
||||||
async def cancel_all(self, msg: str = None) -> None:
|
def cancel_all(self, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels all tasks still running within the pool.
|
Cancels all tasks still running within the pool.
|
||||||
This method blocks, **only if** a currently unknown number of coroutine functions have been registered to be
|
|
||||||
run as tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a
|
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||||
number smaller than the number of arguments from `args_iter`.
|
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||||
TODO: Consider changing this behaviour.
|
than the number of arguments from `args_iter`.
|
||||||
|
In this case, those already running will be cancelled, while the following will **never even start**.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg (optional):
|
msg (optional):
|
||||||
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||||
"""
|
"""
|
||||||
await self._all_tasks_known_flag.wait()
|
log.warning("%s cancelling all tasks!", str(self))
|
||||||
|
self._interrupt_flag.set()
|
||||||
for task in self._running.values():
|
for task in self._running.values():
|
||||||
task.cancel(msg=msg)
|
task.cancel(msg=msg)
|
||||||
|
|
||||||
async def flush(self, return_exceptions: bool = False):
|
async def flush(self, return_exceptions: bool = False):
|
||||||
"""
|
"""
|
||||||
Calls `asyncio.gather` on all ended/cancelled tasks from the task pool, returns their results, and forgets them.
|
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
|
||||||
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
||||||
callbacks registered for the tasks block.
|
callbacks registered for the tasks block.
|
||||||
|
|
||||||
@ -300,6 +303,7 @@ class BaseTaskPool:
|
|||||||
"""
|
"""
|
||||||
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
|
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
|
||||||
self._ended = self._cancelled = {}
|
self._ended = self._cancelled = {}
|
||||||
|
self._interrupt_flag.clear()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
@ -309,16 +313,18 @@ class BaseTaskPool:
|
|||||||
|
|
||||||
async def gather(self, return_exceptions: bool = False):
|
async def gather(self, return_exceptions: bool = False):
|
||||||
"""
|
"""
|
||||||
Calls `asyncio.gather` on all tasks from the task pool, returns their results, and forgets them.
|
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
|
||||||
|
|
||||||
The `close()` method must have been called prior to this.
|
The `close()` method must have been called prior to this.
|
||||||
|
|
||||||
This method blocks, if a currently unknown number of coroutine functions have been registered to be run as
|
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
|
||||||
tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number
|
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
|
||||||
smaller than the number of arguments from `args_iter`.
|
than the number of arguments from `args_iter`.
|
||||||
TODO: Consider changing this behaviour.
|
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
|
||||||
|
blocking this method. Otherwise it will wait until they all have started.
|
||||||
|
|
||||||
This method also blocks, if any of the tasks block while catching a `asyncio.CancelledError` or any of the
|
This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
|
||||||
callbacks registered for the tasks block.
|
callbacks registered for a task blocks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
return_exceptions (optional): Passed directly into `gather`.
|
return_exceptions (optional): Passed directly into `gather`.
|
||||||
@ -332,6 +338,7 @@ class BaseTaskPool:
|
|||||||
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
|
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
|
||||||
return_exceptions=return_exceptions)
|
return_exceptions=return_exceptions)
|
||||||
self._ended = self._cancelled = self._running = {}
|
self._ended = self._cancelled = self._running = {}
|
||||||
|
self._interrupt_flag.clear()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -369,7 +376,7 @@ class TaskPool(BaseTaskPool):
|
|||||||
|
|
||||||
async def _start_next_coroutine() -> bool:
|
async def _start_next_coroutine() -> bool:
|
||||||
cor = self._get_next_coroutine(func, args_iter, arg_stars)
|
cor = self._get_next_coroutine(func, args_iter, arg_stars)
|
||||||
if cor is None:
|
if cor is None or self._interrupt_flag.is_set():
|
||||||
self._all_tasks_known_flag.set()
|
self._all_tasks_known_flag.set()
|
||||||
return True
|
return True
|
||||||
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
|
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
|
||||||
|
@ -52,6 +52,8 @@ class BaseTaskPoolTestCase(TestCase):
|
|||||||
self.assertEqual(self.test_pool_name, self.task_pool._name)
|
self.assertEqual(self.test_pool_name, self.task_pool._name)
|
||||||
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
|
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
|
||||||
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
|
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
|
||||||
|
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
|
||||||
|
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
||||||
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
||||||
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
|
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
|
||||||
self.mock___str__.assert_called_once_with()
|
self.mock___str__.assert_called_once_with()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user