generated from daniil-berg/boilerplate-py
fixed potential race cond. gathering meta tasks
This commit is contained in:
parent
7e34aa106d
commit
a9011076c4
@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
||||
|
||||
|
||||
from asyncio.coroutines import iscoroutinefunction
|
||||
from asyncio.queues import Queue
|
||||
from importlib import import_module
|
||||
from inspect import getdoc
|
||||
from typing import Any, Optional, Union
|
||||
@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
|
||||
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
|
||||
|
||||
|
||||
async def join_queue(q: Queue) -> None:
|
||||
"""Wrapper function around the join method of an `asyncio.Queue` instance."""
|
||||
await q.join()
|
||||
|
||||
|
||||
def get_first_doc_line(obj: object) -> str:
|
||||
"""Takes an object and returns the first (non-empty) line of its docstring."""
|
||||
return getdoc(obj).strip().split("\n", 1)[0].strip()
|
||||
|
@ -42,7 +42,7 @@ from . import exceptions
|
||||
from .queue_context import Queue
|
||||
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
|
||||
from .internals.group_register import TaskGroupRegister
|
||||
from .internals.helpers import execute_optional, star_function, join_queue
|
||||
from .internals.helpers import execute_optional, star_function
|
||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||
|
||||
|
||||
@ -82,8 +82,7 @@ class BaseTaskPool:
|
||||
self._tasks_cancelled: Dict[int, Task] = {}
|
||||
self._tasks_ended: Dict[int, Task] = {}
|
||||
|
||||
# These next three attributes act as synchronisation primitives necessary for managing the pool.
|
||||
self._before_gathering: List[Awaitable] = []
|
||||
# These next two attributes act as synchronisation primitives necessary for managing the pool.
|
||||
self._enough_room: Semaphore = Semaphore()
|
||||
self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
|
||||
|
||||
@ -468,13 +467,11 @@ class BaseTaskPool:
|
||||
"""
|
||||
if not self._locked:
|
||||
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
|
||||
await gather(*self._before_gathering)
|
||||
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._tasks_ended.clear()
|
||||
self._tasks_cancelled.clear()
|
||||
self._tasks_running.clear()
|
||||
self._before_gathering.clear()
|
||||
self._closed = True
|
||||
|
||||
|
||||
@ -499,7 +496,7 @@ class TaskPool(BaseTaskPool):
|
||||
def __init__(self, pool_size: int = inf, name: str = None) -> None:
|
||||
super().__init__(pool_size=pool_size, name=name)
|
||||
# In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
|
||||
# meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks.
|
||||
# meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
|
||||
self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
|
||||
self._meta_tasks_cancelled: Set[Task] = set()
|
||||
|
||||
@ -592,11 +589,11 @@ class TaskPool(BaseTaskPool):
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
"""
|
||||
await super().flush(return_exceptions=return_exceptions)
|
||||
with suppress(CancelledError):
|
||||
await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._meta_tasks_cancelled.clear()
|
||||
await super().flush(return_exceptions=return_exceptions)
|
||||
|
||||
async def gather_and_close(self, return_exceptions: bool = False):
|
||||
"""
|
||||
@ -607,9 +604,8 @@ class TaskPool(BaseTaskPool):
|
||||
The `lock()` method must have been called prior to this.
|
||||
|
||||
Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
|
||||
tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed,
|
||||
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
|
||||
this, make sure to call :meth:`cancel_all` prior to this.
|
||||
tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is
|
||||
fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first.
|
||||
|
||||
This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
|
||||
the callbacks registered for a task blocks for whatever reason.
|
||||
@ -620,15 +616,12 @@ class TaskPool(BaseTaskPool):
|
||||
Raises:
|
||||
`PoolStillUnlocked`: The pool has not been locked yet.
|
||||
"""
|
||||
# TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`).
|
||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
||||
not_cancelled_meta_tasks = set()
|
||||
while self._group_meta_tasks_running:
|
||||
_, meta_tasks = self._group_meta_tasks_running.popitem()
|
||||
not_cancelled_meta_tasks.update(meta_tasks)
|
||||
not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
|
||||
with suppress(CancelledError):
|
||||
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
|
||||
self._meta_tasks_cancelled.clear()
|
||||
self._group_meta_tasks_running.clear()
|
||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
|
||||
@ -789,7 +782,7 @@ class TaskPool(BaseTaskPool):
|
||||
# The following line blocks **only if** the number of running tasks spawned by this method has reached the
|
||||
# specified maximum as determined in :meth:`_map`.
|
||||
await map_semaphore.acquire()
|
||||
# We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
|
||||
# We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called.
|
||||
async with arg_queue as next_arg:
|
||||
if next_arg is self._QUEUE_END_SENTINEL:
|
||||
# The :meth:`_queue_producer` either reached the last argument or was cancelled.
|
||||
@ -797,6 +790,9 @@ class TaskPool(BaseTaskPool):
|
||||
try:
|
||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
||||
except CancelledError:
|
||||
map_semaphore.release()
|
||||
return
|
||||
except Exception as e:
|
||||
# This means an exception occurred during task **creation**, meaning no task has been created.
|
||||
# It does not imply an error within the task itself.
|
||||
@ -856,13 +852,10 @@ class TaskPool(BaseTaskPool):
|
||||
async with group_reg:
|
||||
# Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
|
||||
# `_queue_producer()`; that way an argument
|
||||
# TODO: Perhaps this can be simplified to just one meta-task with no need for a queue.
|
||||
# The limiting factor honoring the group size is already the semaphore in the queue consumer;
|
||||
# Try to write this without a producer, instead consuming the `arg_iter` directly.
|
||||
arg_queue = Queue(maxsize=group_size)
|
||||
# TODO: This is the wrong thing to await before gathering!
|
||||
# Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
|
||||
# "finishes" the entire queue before the producer manages to put more items in it, thus returning
|
||||
# the `join` call before the arguments iterator was fully consumed.
|
||||
# Probably the queue producer task should be awaited before gathering instead.
|
||||
self._before_gathering.append(join_queue(arg_queue))
|
||||
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||
# Start the producer and consumer meta tasks.
|
||||
meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
|
||||
|
@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
helpers.star_function(f, a, 123456789)
|
||||
|
||||
async def test_join_queue(self):
|
||||
mock_join = AsyncMock()
|
||||
mock_queue = MagicMock(join=mock_join)
|
||||
self.assertIsNone(await helpers.join_queue(mock_queue))
|
||||
mock_join.assert_awaited_once_with()
|
||||
|
||||
def test_get_first_doc_line(self):
|
||||
expected_output = 'foo bar baz'
|
||||
mock_obj = MagicMock(__doc__=f"""{expected_output}
|
||||
|
@ -93,7 +93,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||
|
||||
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
|
||||
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||
|
||||
@ -366,9 +365,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||
|
||||
async def test_gather_and_close(self):
|
||||
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
|
||||
mock_running_func = AsyncMock()
|
||||
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
||||
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
|
||||
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
|
||||
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
|
||||
self.task_pool._tasks_running = running = {789: mock_running_func()}
|
||||
@ -378,19 +376,16 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
self.assertDictEqual(ended, self.task_pool._tasks_ended)
|
||||
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
|
||||
self.assertDictEqual(running, self.task_pool._tasks_running)
|
||||
self.assertListEqual(before_gather, self.task_pool._before_gathering)
|
||||
self.assertFalse(self.task_pool._closed)
|
||||
|
||||
self.task_pool._locked = True
|
||||
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
||||
mock_before_gather.assert_awaited_once_with()
|
||||
mock_ended_func.assert_awaited_once_with()
|
||||
mock_cancelled_func.assert_awaited_once_with()
|
||||
mock_running_func.assert_awaited_once_with()
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
|
||||
self.assertTrue(self.task_pool._closed)
|
||||
|
||||
|
||||
@ -623,19 +618,32 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
call(mock_func, bad, arg_stars=stars)
|
||||
])
|
||||
|
||||
mock_semaphore_cls.reset_mock()
|
||||
mock__get_map_end_callback.reset_mock()
|
||||
mock__start_task.reset_mock()
|
||||
mock_star_function.reset_mock()
|
||||
|
||||
# With a CancelledError thrown while starting a task:
|
||||
mock_semaphore_cls.return_value = semaphore = Semaphore(1)
|
||||
mock_star_function.side_effect = CancelledError()
|
||||
mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize)
|
||||
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
|
||||
self.assertFalse(semaphore.locked())
|
||||
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
|
||||
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
||||
mock__start_task.assert_not_called()
|
||||
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
|
||||
|
||||
@patch.object(pool, 'create_task')
|
||||
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
|
||||
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
|
||||
@patch.object(pool, 'join_queue', new_callable=MagicMock)
|
||||
@patch.object(pool, 'Queue')
|
||||
@patch.object(pool, 'TaskGroupRegister')
|
||||
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
|
||||
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
|
||||
mock_create_task: MagicMock):
|
||||
mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock):
|
||||
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||
mock_queue_cls.return_value = mock_q = MagicMock()
|
||||
mock_join_queue.return_value = fake_join = object()
|
||||
mock__queue_producer.return_value = fake_producer = object()
|
||||
mock__queue_consumer.return_value = fake_consumer = object()
|
||||
fake_task1, fake_task2 = object(), object()
|
||||
@ -669,8 +677,6 @@ class TaskPoolTestCase(CommonTestCase):
|
||||
self.task_pool._task_groups[group_name] = mock_group_reg
|
||||
mock_group_reg.__aenter__.assert_awaited_once_with()
|
||||
mock_queue_cls.assert_called_once_with(maxsize=group_size)
|
||||
mock_join_queue.assert_called_once_with(mock_q)
|
||||
self.assertListEqual([fake_join], self.task_pool._before_gathering)
|
||||
mock__queue_producer.assert_called_once()
|
||||
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
|
||||
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
|
||||
|
Loading…
Reference in New Issue
Block a user