Fixed a potential race condition when gathering meta tasks

This commit is contained in:
Daniil Fajnberg 2022-03-25 12:58:18 +01:00
parent 7e34aa106d
commit a9011076c4
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
4 changed files with 34 additions and 47 deletions

View File

@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from importlib import import_module from importlib import import_module
from inspect import getdoc from inspect import getdoc
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.") raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
async def join_queue(q: Queue) -> None:
"""Wrapper function around the join method of an `asyncio.Queue` instance."""
await q.join()
def get_first_doc_line(obj: object) -> str: def get_first_doc_line(obj: object) -> str:
"""Takes an object and returns the first (non-empty) line of its docstring.""" """Takes an object and returns the first (non-empty) line of its docstring."""
return getdoc(obj).strip().split("\n", 1)[0].strip() return getdoc(obj).strip().split("\n", 1)[0].strip()

View File

@ -42,7 +42,7 @@ from . import exceptions
from .queue_context import Queue from .queue_context import Queue
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
from .internals.group_register import TaskGroupRegister from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function, join_queue from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -82,8 +82,7 @@ class BaseTaskPool:
self._tasks_cancelled: Dict[int, Task] = {} self._tasks_cancelled: Dict[int, Task] = {}
self._tasks_ended: Dict[int, Task] = {} self._tasks_ended: Dict[int, Task] = {}
# These next three attributes act as synchronisation primitives necessary for managing the pool. # These next two attributes act as synchronisation primitives necessary for managing the pool.
self._before_gathering: List[Awaitable] = []
self._enough_room: Semaphore = Semaphore() self._enough_room: Semaphore = Semaphore()
self._task_groups: Dict[str, TaskGroupRegister[int]] = {} self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
@ -468,13 +467,11 @@ class BaseTaskPool:
""" """
if not self._locked: if not self._locked:
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered") raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
await gather(*self._before_gathering)
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(), await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
return_exceptions=return_exceptions) return_exceptions=return_exceptions)
self._tasks_ended.clear() self._tasks_ended.clear()
self._tasks_cancelled.clear() self._tasks_cancelled.clear()
self._tasks_running.clear() self._tasks_running.clear()
self._before_gathering.clear()
self._closed = True self._closed = True
@ -499,7 +496,7 @@ class TaskPool(BaseTaskPool):
def __init__(self, pool_size: int = inf, name: str = None) -> None: def __init__(self, pool_size: int = inf, name: str = None) -> None:
super().__init__(pool_size=pool_size, name=name) super().__init__(pool_size=pool_size, name=name)
# In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
# meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks. # meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
self._group_meta_tasks_running: Dict[str, Set[Task]] = {} self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
self._meta_tasks_cancelled: Set[Task] = set() self._meta_tasks_cancelled: Set[Task] = set()
@ -592,11 +589,11 @@ class TaskPool(BaseTaskPool):
Args: Args:
return_exceptions (optional): Passed directly into `gather`. return_exceptions (optional): Passed directly into `gather`.
""" """
await super().flush(return_exceptions=return_exceptions)
with suppress(CancelledError): with suppress(CancelledError):
await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(), await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
return_exceptions=return_exceptions) return_exceptions=return_exceptions)
self._meta_tasks_cancelled.clear() self._meta_tasks_cancelled.clear()
await super().flush(return_exceptions=return_exceptions)
async def gather_and_close(self, return_exceptions: bool = False): async def gather_and_close(self, return_exceptions: bool = False):
""" """
@ -607,9 +604,8 @@ class TaskPool(BaseTaskPool):
The `lock()` method must have been called prior to this. The `lock()` method must have been called prior to this.
Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed, tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first.
this, make sure to call :meth:`cancel_all` prior to this.
This method may also block, if one of the tasks blocks while catching an `asyncio.CancelledError` or if any of This method may also block, if one of the tasks blocks while catching an `asyncio.CancelledError` or if any of
the callbacks registered for a task blocks for whatever reason. the callbacks registered for a task blocks for whatever reason.
@ -620,15 +616,12 @@ class TaskPool(BaseTaskPool):
Raises: Raises:
`PoolStillUnlocked`: The pool has not been locked yet. `PoolStillUnlocked`: The pool has not been locked yet.
""" """
# TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`). not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
await super().gather_and_close(return_exceptions=return_exceptions)
not_cancelled_meta_tasks = set()
while self._group_meta_tasks_running:
_, meta_tasks = self._group_meta_tasks_running.popitem()
not_cancelled_meta_tasks.update(meta_tasks)
with suppress(CancelledError): with suppress(CancelledError):
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions) await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
self._meta_tasks_cancelled.clear() self._meta_tasks_cancelled.clear()
self._group_meta_tasks_running.clear()
await super().gather_and_close(return_exceptions=return_exceptions)
@staticmethod @staticmethod
def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str: def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
@ -789,7 +782,7 @@ class TaskPool(BaseTaskPool):
# The following line blocks **only if** the number of running tasks spawned by this method has reached the # The following line blocks **only if** the number of running tasks spawned by this method has reached the
# specified maximum as determined in :meth:`_map`. # specified maximum as determined in :meth:`_map`.
await map_semaphore.acquire() await map_semaphore.acquire()
# We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called. # We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called.
async with arg_queue as next_arg: async with arg_queue as next_arg:
if next_arg is self._QUEUE_END_SENTINEL: if next_arg is self._QUEUE_END_SENTINEL:
# The :meth:`_queue_producer` either reached the last argument or was cancelled. # The :meth:`_queue_producer` either reached the last argument or was cancelled.
@ -797,6 +790,9 @@ class TaskPool(BaseTaskPool):
try: try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
map_semaphore.release()
return
except Exception as e: except Exception as e:
# This means an exception occurred during task **creation**, meaning no task has been created. # This means an exception occurred during task **creation**, meaning no task has been created.
# It does not imply an error within the task itself. # It does not imply an error within the task itself.
@ -856,13 +852,10 @@ class TaskPool(BaseTaskPool):
async with group_reg: async with group_reg:
# Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
# `_queue_producer()`; that way an argument # `_queue_producer()`; that way an argument
# TODO: Perhaps this can be simplified to just one meta-task with no need for a queue.
# The limiting factor honoring the group size is already the semaphore in the queue consumer;
# Try to write this without a producer, instead consuming the `arg_iter` directly.
arg_queue = Queue(maxsize=group_size) arg_queue = Queue(maxsize=group_size)
# TODO: This is the wrong thing to await before gathering!
# Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
# "finishes" the entire queue before the producer manages to put more items in it, thus returning
# the `join` call before the arguments iterator was fully consumed.
# Probably the queue producer task should be awaited before gathering instead.
self._before_gathering.append(join_queue(arg_queue))
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
# Start the producer and consumer meta tasks. # Start the producer and consumer meta tasks.
meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name))) meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))

View File

@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
helpers.star_function(f, a, 123456789) helpers.star_function(f, a, 123456789)
async def test_join_queue(self):
mock_join = AsyncMock()
mock_queue = MagicMock(join=mock_join)
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()
def test_get_first_doc_line(self): def test_get_first_doc_line(self):
expected_output = 'foo bar baz' expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output} mock_obj = MagicMock(__doc__=f"""{expected_output}

View File

@ -93,7 +93,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._enough_room, Semaphore) self.assertIsInstance(self.task_pool._enough_room, Semaphore)
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
@ -366,9 +365,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
async def test_gather_and_close(self): async def test_gather_and_close(self):
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock() mock_running_func = AsyncMock()
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
self.task_pool._tasks_ended = ended = {123: mock_ended_func()} self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()} self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._tasks_running = running = {789: mock_running_func()} self.task_pool._tasks_running = running = {789: mock_running_func()}
@ -378,19 +376,16 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(ended, self.task_pool._tasks_ended) self.assertDictEqual(ended, self.task_pool._tasks_ended)
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled) self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
self.assertDictEqual(running, self.task_pool._tasks_running) self.assertDictEqual(running, self.task_pool._tasks_running)
self.assertListEqual(before_gather, self.task_pool._before_gathering)
self.assertFalse(self.task_pool._closed) self.assertFalse(self.task_pool._closed)
self.task_pool._locked = True self.task_pool._locked = True
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_before_gather.assert_awaited_once_with()
mock_ended_func.assert_awaited_once_with() mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with() mock_cancelled_func.assert_awaited_once_with()
mock_running_func.assert_awaited_once_with() mock_running_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
self.assertTrue(self.task_pool._closed) self.assertTrue(self.task_pool._closed)
@ -623,19 +618,32 @@ class TaskPoolTestCase(CommonTestCase):
call(mock_func, bad, arg_stars=stars) call(mock_func, bad, arg_stars=stars)
]) ])
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock()
# With a CancelledError thrown while starting a task:
mock_semaphore_cls.return_value = semaphore = Semaphore(1)
mock_star_function.side_effect = CancelledError()
mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize)
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
self.assertFalse(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_not_called()
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool, 'Queue') @patch.object(pool, 'Queue')
@patch.object(pool, 'TaskGroupRegister') @patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start') @patch.object(pool.BaseTaskPool, '_check_start')
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock, async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock):
mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_queue_cls.return_value = mock_q = MagicMock() mock_queue_cls.return_value = mock_q = MagicMock()
mock_join_queue.return_value = fake_join = object()
mock__queue_producer.return_value = fake_producer = object() mock__queue_producer.return_value = fake_producer = object()
mock__queue_consumer.return_value = fake_consumer = object() mock__queue_consumer.return_value = fake_consumer = object()
fake_task1, fake_task2 = object(), object() fake_task1, fake_task2 = object(), object()
@ -669,8 +677,6 @@ class TaskPoolTestCase(CommonTestCase):
self.task_pool._task_groups[group_name] = mock_group_reg self.task_pool._task_groups[group_name] = mock_group_reg
mock_group_reg.__aenter__.assert_awaited_once_with() mock_group_reg.__aenter__.assert_awaited_once_with()
mock_queue_cls.assert_called_once_with(maxsize=group_size) mock_queue_cls.assert_called_once_with(maxsize=group_size)
mock_join_queue.assert_called_once_with(mock_q)
self.assertListEqual([fake_join], self.task_pool._before_gathering)
mock__queue_producer.assert_called_once() mock__queue_producer.assert_called_once()
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb) mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)]) mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])