From a9011076c49611367f508a6b72762bf0f186ac5d Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Fri, 25 Mar 2022 12:58:18 +0100 Subject: [PATCH] fixed potential race cond. gathering meta tasks --- src/asyncio_taskpool/internals/helpers.py | 6 ---- src/asyncio_taskpool/pool.py | 39 ++++++++++------------- tests/test_internals/test_helpers.py | 6 ---- tests/test_pool.py | 30 ++++++++++------- 4 files changed, 34 insertions(+), 47 deletions(-) diff --git a/src/asyncio_taskpool/internals/helpers.py b/src/asyncio_taskpool/internals/helpers.py index 06fb76b..7deb381 100644 --- a/src/asyncio_taskpool/internals/helpers.py +++ b/src/asyncio_taskpool/internals/helpers.py @@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p from asyncio.coroutines import iscoroutinefunction -from asyncio.queues import Queue from importlib import import_module from inspect import getdoc from typing import Any, Optional, Union @@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T: raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.") -async def join_queue(q: Queue) -> None: - """Wrapper function around the join method of an `asyncio.Queue` instance.""" - await q.join() - - def get_first_doc_line(obj: object) -> str: """Takes an object and returns the first (non-empty) line of its docstring.""" return getdoc(obj).strip().split("\n", 1)[0].strip() diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index d44e6c4..34df593 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -42,7 +42,7 @@ from . import exceptions from .queue_context import Queue from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT from .internals.group_register import TaskGroupRegister -from .internals.helpers import execute_optional, star_function, join_queue +from .internals.helpers import execute_optional, star_function from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB @@ -82,8 +82,7 @@ class BaseTaskPool: self._tasks_cancelled: Dict[int, Task] = {} self._tasks_ended: Dict[int, Task] = {} - # These next three attributes act as synchronisation primitives necessary for managing the pool. - self._before_gathering: List[Awaitable] = [] + # These next two attributes act as synchronisation primitives necessary for managing the pool. self._enough_room: Semaphore = Semaphore() self._task_groups: Dict[str, TaskGroupRegister[int]] = {} @@ -468,13 +467,11 @@ class BaseTaskPool: """ if not self._locked: raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered") - await gather(*self._before_gathering) await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(), return_exceptions=return_exceptions) self._tasks_ended.clear() self._tasks_cancelled.clear() self._tasks_running.clear() - self._before_gathering.clear() self._closed = True @@ -499,7 +496,7 @@ class TaskPool(BaseTaskPool): def __init__(self, pool_size: int = inf, name: str = None) -> None: super().__init__(pool_size=pool_size, name=name) # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of - # meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks. + # meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks. self._group_meta_tasks_running: Dict[str, Set[Task]] = {} self._meta_tasks_cancelled: Set[Task] = set() @@ -592,11 +589,11 @@ class TaskPool(BaseTaskPool): Args: return_exceptions (optional): Passed directly into `gather`. """ - await super().flush(return_exceptions=return_exceptions) with suppress(CancelledError): await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(), return_exceptions=return_exceptions) self._meta_tasks_cancelled.clear() + await super().flush(return_exceptions=return_exceptions) async def gather_and_close(self, return_exceptions: bool = False): """ @@ -607,9 +604,8 @@ class TaskPool(BaseTaskPool): The `lock()` method must have been called prior to this. Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta - tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed, - which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid - this, make sure to call :meth:`cancel_all` prior to this. + tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is + fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first. This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of the callbacks registered for a task blocks for whatever reason. @@ -620,15 +616,12 @@ class TaskPool(BaseTaskPool): Raises: `PoolStillUnlocked`: The pool has not been locked yet. """ - # TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`). - await super().gather_and_close(return_exceptions=return_exceptions) - not_cancelled_meta_tasks = set() - while self._group_meta_tasks_running: - _, meta_tasks = self._group_meta_tasks_running.popitem() - not_cancelled_meta_tasks.update(meta_tasks) + not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set) with suppress(CancelledError): await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions) self._meta_tasks_cancelled.clear() + self._group_meta_tasks_running.clear() + await super().gather_and_close(return_exceptions=return_exceptions) @staticmethod def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str: @@ -789,7 +782,7 @@ class TaskPool(BaseTaskPool): # The following line blocks **only if** the number of running tasks spawned by this method has reached the # specified maximum as determined in :meth:`_map`. await map_semaphore.acquire() - # We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called. + # We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called. async with arg_queue as next_arg: if next_arg is self._QUEUE_END_SENTINEL: # The :meth:`_queue_producer` either reached the last argument or was cancelled. @@ -797,6 +790,9 @@ class TaskPool(BaseTaskPool): try: await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) + except CancelledError: + map_semaphore.release() + return except Exception as e: # This means an exception occurred during task **creation**, meaning no task has been created. # It does not imply an error within the task itself. @@ -856,13 +852,10 @@ class TaskPool(BaseTaskPool): async with group_reg: # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the # `_queue_producer()`; that way an argument + # TODO: Perhaps this can be simplified to just one meta-task with no need for a queue. + # The limiting factor honoring the group size is already the semaphore in the queue consumer; + # Try to write this without a producer, instead consuming the `arg_iter` directly. arg_queue = Queue(maxsize=group_size) - # TODO: This is the wrong thing to await before gathering! - # Since the queue producer and consumer operate in separate tasks, it is possible that the consumer - # "finishes" the entire queue before the producer manages to put more items in it, thus returning - # the `join` call before the arguments iterator was fully consumed. - # Probably the queue producer task should be awaited before gathering instead. - self._before_gathering.append(join_queue(arg_queue)) meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) # Start the producer and consumer meta tasks. meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name))) diff --git a/tests/test_internals/test_helpers.py b/tests/test_internals/test_helpers.py index 9ed60ad..6b66aee 100644 --- a/tests/test_internals/test_helpers.py +++ b/tests/test_internals/test_helpers.py @@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase): with self.assertRaises(ValueError): helpers.star_function(f, a, 123456789) - async def test_join_queue(self): - mock_join = AsyncMock() - mock_queue = MagicMock(join=mock_join) - self.assertIsNone(await helpers.join_queue(mock_queue)) - mock_join.assert_awaited_once_with() - def test_get_first_doc_line(self): expected_output = 'foo bar baz' mock_obj = MagicMock(__doc__=f"""{expected_output} diff --git a/tests/test_pool.py b/tests/test_pool.py index d34fc85..d2a82b5 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -93,7 +93,6 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) - self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) self.assertIsInstance(self.task_pool._enough_room, Semaphore) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) @@ -366,9 +365,8 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) async def test_gather_and_close(self): - mock_before_gather, mock_running_func = AsyncMock(), AsyncMock() + mock_running_func = AsyncMock() mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) - self.task_pool._before_gathering = before_gather = [mock_before_gather()] self.task_pool._tasks_ended = ended = {123: mock_ended_func()} self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()} self.task_pool._tasks_running = running = {789: mock_running_func()} @@ -378,19 +376,16 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertDictEqual(ended, self.task_pool._tasks_ended) self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled) self.assertDictEqual(running, self.task_pool._tasks_running) - self.assertListEqual(before_gather, self.task_pool._before_gathering) self.assertFalse(self.task_pool._closed) self.task_pool._locked = True self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) - mock_before_gather.assert_awaited_once_with() mock_ended_func.assert_awaited_once_with() mock_cancelled_func.assert_awaited_once_with() mock_running_func.assert_awaited_once_with() self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) - self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering) self.assertTrue(self.task_pool._closed) @@ -623,19 +618,32 @@ class TaskPoolTestCase(CommonTestCase): call(mock_func, bad, arg_stars=stars) ]) + mock_semaphore_cls.reset_mock() + mock__get_map_end_callback.reset_mock() + mock__start_task.reset_mock() + mock_star_function.reset_mock() + + # With a CancelledError thrown while starting a task: + mock_semaphore_cls.return_value = semaphore = Semaphore(1) + mock_star_function.side_effect = CancelledError() + mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize) + self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) + self.assertFalse(semaphore.locked()) + mock_semaphore_cls.assert_called_once_with(mock_q_maxsize) + mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) + mock__start_task.assert_not_called() + mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) + @patch.object(pool, 'create_task') @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) - @patch.object(pool, 'join_queue', new_callable=MagicMock) @patch.object(pool, 'Queue') @patch.object(pool, 'TaskGroupRegister') @patch.object(pool.BaseTaskPool, '_check_start') async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock, - mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, - mock_create_task: MagicMock): + mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock): mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock_queue_cls.return_value = mock_q = MagicMock() - mock_join_queue.return_value = fake_join = object() mock__queue_producer.return_value = fake_producer = object() mock__queue_consumer.return_value = fake_consumer = object() fake_task1, fake_task2 = object(), object() @@ -669,8 +677,6 @@ class TaskPoolTestCase(CommonTestCase): self.task_pool._task_groups[group_name] = mock_group_reg mock_group_reg.__aenter__.assert_awaited_once_with() mock_queue_cls.assert_called_once_with(maxsize=group_size) - mock_join_queue.assert_called_once_with(mock_q) - self.assertListEqual([fake_join], self.task_pool._before_gathering) mock__queue_producer.assert_called_once() mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb) mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])