diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py
index e40e632..0b2ec46 100644
--- a/src/asyncio_taskpool/pool.py
+++ b/src/asyncio_taskpool/pool.py
@@ -39,6 +39,7 @@ class BaseTaskPool:
         self._name: str = name
         self._all_tasks_known_flag: Event = Event()
         self._all_tasks_known_flag.set()
+        self._interrupt_flag: Event = Event()
         log.debug("%s initialized", str(self))
 
     def __str__(self) -> str:
@@ -188,7 +189,7 @@ class BaseTaskPool:
                           cancel_callback: CancelCallbackT = None) -> int:
         """
         Starts a coroutine as a new task in the pool.
-        This method blocks, **only if** there the pool is full.
+        This method blocks, **only if** the pool is full.
         Returns/raises whatever the wrapped coroutine does.
 
         Args:
@@ -273,25 +274,27 @@ class BaseTaskPool:
         for task in tasks:
             task.cancel(msg=msg)
 
-    async def cancel_all(self, msg: str = None) -> None:
+    def cancel_all(self, msg: str = None) -> None:
         """
         Cancels all tasks still running within the pool.
-        This method blocks, **only if** a currently unknown number of coroutine functions have been registered to be
-        run as tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a
-        number smaller than the number of arguments from `args_iter`.
-        TODO: Consider changing this behaviour.
+
+        Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
+        This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
+        than the number of arguments from `args_iter`.
+        In this case, those already running will be cancelled, while the following will **never even start**.
 
         Args:
             msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
         """
-        await self._all_tasks_known_flag.wait()
+        log.warning("%s cancelling all tasks!", str(self))
+        self._interrupt_flag.set()
         for task in self._running.values():
             task.cancel(msg=msg)
 
     async def flush(self, return_exceptions: bool = False):
         """
-        Calls `asyncio.gather` on all ended/cancelled tasks from the task pool, returns their results, and forgets them.
+        Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
 
         This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
         callbacks registered for the tasks block.
@@ -300,6 +303,7 @@ class BaseTaskPool:
         """
         results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
         self._ended = self._cancelled = {}
+        self._interrupt_flag.clear()
         return results
 
     def close(self) -> None:
@@ -309,16 +313,18 @@ class BaseTaskPool:
 
     async def gather(self, return_exceptions: bool = False):
         """
-        Calls `asyncio.gather` on all tasks from the task pool, returns their results, and forgets them.
+        Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
+        The `close()` method must have been called prior to this.
 
-        This method blocks, if a currently unknown number of coroutine functions have been registered to be run as
-        tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number
-        smaller than the number of arguments from `args_iter`.
-        TODO: Consider changing this behaviour.
+        Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
+        This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
+        than the number of arguments from `args_iter`.
+        In this case, calling `cancel_all()` prior to this will prevent those tasks from starting and potentially
+        blocking this method. Otherwise, it will wait until they all have started.
 
-        This method also blocks, if any of the tasks block while catching a `asyncio.CancelledError` or any of the
-        callbacks registered for the tasks block.
+        This method may also block, if any task blocks while catching an `asyncio.CancelledError` or if any of the
+        callbacks registered for a task blocks.
 
         Args:
             return_exceptions (optional): Passed directly into `gather`.
@@ -332,6 +338,7 @@ class BaseTaskPool:
         results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
                                return_exceptions=return_exceptions)
         self._ended = self._cancelled = self._running = {}
+        self._interrupt_flag.clear()
         return results
 
 
@@ -369,7 +376,7 @@ class TaskPool(BaseTaskPool):
 
         async def _start_next_coroutine() -> bool:
             cor = self._get_next_coroutine(func, args_iter, arg_stars)
-            if cor is None:
+            if cor is None or self._interrupt_flag.is_set():
                 self._all_tasks_known_flag.set()
                 return True
             await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
diff --git a/tests/test_pool.py b/tests/test_pool.py
index 96b4641..b2340b2 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -52,6 +52,8 @@ class BaseTaskPoolTestCase(TestCase):
         self.assertEqual(self.test_pool_name, self.task_pool._name)
         self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
         self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
+        self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
+        self.assertFalse(self.task_pool._interrupt_flag.is_set())
         self.mock__add_pool.assert_called_once_with(self.task_pool)
         self.mock_pool_size.assert_called_once_with(self.test_pool_size)
         self.mock___str__.assert_called_once_with()
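
For context, a minimal usage sketch of the behaviour this patch introduces. The import path (`asyncio_taskpool.pool`) and the method names are taken from the diff; the `TaskPool()` constructor defaults, the exact `map(func, args_iter, num_tasks=...)` signature (including it being awaitable), and the `work` coroutine are assumptions made for illustration only.

```python
import asyncio

from asyncio_taskpool.pool import TaskPool  # module path as in this diff


async def work(i: int) -> int:
    # Stand-in workload; not part of the library.
    await asyncio.sleep(10)
    return i


async def main() -> None:
    pool = TaskPool()  # constructor defaults assumed
    # Assumed `map` usage (names taken from the docstrings above): more
    # arguments than `num_tasks`, so most coroutines stay "queued" unstarted.
    await pool.map(work, range(10), num_tasks=2)
    await asyncio.sleep(0.1)

    # With this patch, `cancel_all()` is synchronous: it sets the interrupt
    # flag, cancels the tasks already running, and the queued coroutines
    # never even start.
    pool.cancel_all(msg="shutting down")

    # `gather()` still requires `close()` to have been called first; after
    # `cancel_all()` it no longer blocks waiting for the queued coroutines.
    pool.close()
    results = await pool.gather(return_exceptions=True)
    print(results)  # CancelledError instances for the tasks that had started


if __name__ == "__main__":
    asyncio.run(main())
```

The key interaction: once `cancel_all()` sets `_interrupt_flag`, `_start_next_coroutine` refuses to start the remaining queued coroutines (and sets `_all_tasks_known_flag` so nothing keeps waiting on them), while `flush()` and `gather()` clear the flag again afterwards.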