diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index c7f4993..232d058 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -150,22 +150,26 @@ class BaseTaskPool: raise e return task_id - def _cancel_one(self, task_id: int, msg: str = None) -> None: + def _get_running_task(self, task_id: int) -> Task: try: - task = self._running[task_id] + return self._running[task_id] except KeyError: if self._cancelled.get(task_id): raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled") if self._ended.get(task_id): raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") - task.cancel(msg=msg) + + def _cancel_task(self, task_id: int, msg: str = None) -> None: + self._get_running_task(task_id).cancel(msg=msg) def cancel(self, *task_ids: int, msg: str = None) -> None: - for task_id in task_ids: - self._cancel_one(task_id, msg=msg) + tasks = [self._get_running_task(task_id) for task_id in task_ids] + for task in tasks: + task.cancel(msg=msg) - def cancel_all(self, msg: str = None) -> None: + async def cancel_all(self, msg: str = None) -> None: + await self._all_tasks_known_flag.wait() for task in self._running.values(): task.cancel(msg=msg)