generated from daniil-berg/boilerplate-py
	fixed potential race cond. gathering meta tasks
This commit is contained in:
		@@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from asyncio.coroutines import iscoroutinefunction
 | 
			
		||||
from asyncio.queues import Queue
 | 
			
		||||
from importlib import import_module
 | 
			
		||||
from inspect import getdoc
 | 
			
		||||
from typing import Any, Optional, Union
 | 
			
		||||
@@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
 | 
			
		||||
    raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def join_queue(q: Queue) -> None:
 | 
			
		||||
    """Wrapper function around the join method of an `asyncio.Queue` instance."""
 | 
			
		||||
    await q.join()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_first_doc_line(obj: object) -> str:
 | 
			
		||||
    """Takes an object and returns the first (non-empty) line of its docstring."""
 | 
			
		||||
    return getdoc(obj).strip().split("\n", 1)[0].strip()
 | 
			
		||||
 
 | 
			
		||||
@@ -42,7 +42,7 @@ from . import exceptions
 | 
			
		||||
from .queue_context import Queue
 | 
			
		||||
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
 | 
			
		||||
from .internals.group_register import TaskGroupRegister
 | 
			
		||||
from .internals.helpers import execute_optional, star_function, join_queue
 | 
			
		||||
from .internals.helpers import execute_optional, star_function
 | 
			
		||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -82,8 +82,7 @@ class BaseTaskPool:
 | 
			
		||||
        self._tasks_cancelled: Dict[int, Task] = {}
 | 
			
		||||
        self._tasks_ended: Dict[int, Task] = {}
 | 
			
		||||
 | 
			
		||||
        # These next three attributes act as synchronisation primitives necessary for managing the pool.
 | 
			
		||||
        self._before_gathering: List[Awaitable] = []
 | 
			
		||||
        # These next two attributes act as synchronisation primitives necessary for managing the pool.
 | 
			
		||||
        self._enough_room: Semaphore = Semaphore()
 | 
			
		||||
        self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
 | 
			
		||||
 | 
			
		||||
@@ -468,13 +467,11 @@ class BaseTaskPool:
 | 
			
		||||
        """
 | 
			
		||||
        if not self._locked:
 | 
			
		||||
            raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
 | 
			
		||||
        await gather(*self._before_gathering)
 | 
			
		||||
        await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
 | 
			
		||||
                     return_exceptions=return_exceptions)
 | 
			
		||||
        self._tasks_ended.clear()
 | 
			
		||||
        self._tasks_cancelled.clear()
 | 
			
		||||
        self._tasks_running.clear()
 | 
			
		||||
        self._before_gathering.clear()
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -499,7 +496,7 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
    def __init__(self, pool_size: int = inf, name: str = None) -> None:
 | 
			
		||||
        super().__init__(pool_size=pool_size, name=name)
 | 
			
		||||
        # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
 | 
			
		||||
        # meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks.
 | 
			
		||||
        # meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
 | 
			
		||||
        self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
 | 
			
		||||
        self._meta_tasks_cancelled: Set[Task] = set()
 | 
			
		||||
 | 
			
		||||
@@ -592,11 +589,11 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
        Args:
 | 
			
		||||
            return_exceptions (optional): Passed directly into `gather`.
 | 
			
		||||
        """
 | 
			
		||||
        await super().flush(return_exceptions=return_exceptions)
 | 
			
		||||
        with suppress(CancelledError):
 | 
			
		||||
            await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
 | 
			
		||||
                         return_exceptions=return_exceptions)
 | 
			
		||||
        self._meta_tasks_cancelled.clear()
 | 
			
		||||
        await super().flush(return_exceptions=return_exceptions)
 | 
			
		||||
 | 
			
		||||
    async def gather_and_close(self, return_exceptions: bool = False):
 | 
			
		||||
        """
 | 
			
		||||
@@ -607,9 +604,8 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
        The `lock()` method must have been called prior to this.
 | 
			
		||||
 | 
			
		||||
        Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
 | 
			
		||||
        tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed,
 | 
			
		||||
        which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
 | 
			
		||||
        this, make sure to call :meth:`cancel_all` prior to this.
 | 
			
		||||
        tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is
 | 
			
		||||
        fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first.
 | 
			
		||||
 | 
			
		||||
        This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
 | 
			
		||||
        the callbacks registered for a task blocks for whatever reason.
 | 
			
		||||
@@ -620,15 +616,12 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
        Raises:
 | 
			
		||||
            `PoolStillUnlocked`: The pool has not been locked yet.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`).
 | 
			
		||||
        await super().gather_and_close(return_exceptions=return_exceptions)
 | 
			
		||||
        not_cancelled_meta_tasks = set()
 | 
			
		||||
        while self._group_meta_tasks_running:
 | 
			
		||||
            _, meta_tasks = self._group_meta_tasks_running.popitem()
 | 
			
		||||
            not_cancelled_meta_tasks.update(meta_tasks)
 | 
			
		||||
        not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
 | 
			
		||||
        with suppress(CancelledError):
 | 
			
		||||
            await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
 | 
			
		||||
        self._meta_tasks_cancelled.clear()
 | 
			
		||||
        self._group_meta_tasks_running.clear()
 | 
			
		||||
        await super().gather_and_close(return_exceptions=return_exceptions)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
 | 
			
		||||
@@ -789,7 +782,7 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
            # The following line blocks **only if** the number of running tasks spawned by this method has reached the
 | 
			
		||||
            # specified maximum as determined in :meth:`_map`.
 | 
			
		||||
            await map_semaphore.acquire()
 | 
			
		||||
            # We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
 | 
			
		||||
            # We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called.
 | 
			
		||||
            async with arg_queue as next_arg:
 | 
			
		||||
                if next_arg is self._QUEUE_END_SENTINEL:
 | 
			
		||||
                    # The :meth:`_queue_producer` either reached the last argument or was cancelled.
 | 
			
		||||
@@ -797,6 +790,9 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
                try:
 | 
			
		||||
                    await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
 | 
			
		||||
                                           ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
 | 
			
		||||
                except CancelledError:
 | 
			
		||||
                    map_semaphore.release()
 | 
			
		||||
                    return
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    # This means an exception occurred during task **creation**, meaning no task has been created.
 | 
			
		||||
                    # It does not imply an error within the task itself.
 | 
			
		||||
@@ -856,13 +852,10 @@ class TaskPool(BaseTaskPool):
 | 
			
		||||
        async with group_reg:
 | 
			
		||||
            # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
 | 
			
		||||
            # `_queue_producer()`; that way an argument
 | 
			
		||||
            # TODO: Perhaps this can be simplified to just one meta-task with no need for a queue.
 | 
			
		||||
            #       The limiting factor honoring the group size is already the semaphore in the queue consumer;
 | 
			
		||||
            #       Try to write this without a producer, instead consuming the `arg_iter` directly.
 | 
			
		||||
            arg_queue = Queue(maxsize=group_size)
 | 
			
		||||
            # TODO: This is the wrong thing to await before gathering!
 | 
			
		||||
            #       Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
 | 
			
		||||
            #       "finishes" the entire queue before the producer manages to put more items in it, thus returning
 | 
			
		||||
            #       the `join` call before the arguments iterator was fully consumed.
 | 
			
		||||
            #       Probably the queue producer task should be awaited before gathering instead.
 | 
			
		||||
            self._before_gathering.append(join_queue(arg_queue))
 | 
			
		||||
            meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
 | 
			
		||||
            # Start the producer and consumer meta tasks.
 | 
			
		||||
            meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
 | 
			
		||||
 
 | 
			
		||||
@@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            helpers.star_function(f, a, 123456789)
 | 
			
		||||
 | 
			
		||||
    async def test_join_queue(self):
 | 
			
		||||
        mock_join = AsyncMock()
 | 
			
		||||
        mock_queue = MagicMock(join=mock_join)
 | 
			
		||||
        self.assertIsNone(await helpers.join_queue(mock_queue))
 | 
			
		||||
        mock_join.assert_awaited_once_with()
 | 
			
		||||
 | 
			
		||||
    def test_get_first_doc_line(self):
 | 
			
		||||
        expected_output = 'foo bar baz'
 | 
			
		||||
        mock_obj = MagicMock(__doc__=f"""{expected_output} 
 | 
			
		||||
 
 | 
			
		||||
@@ -93,7 +93,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
 | 
			
		||||
 | 
			
		||||
        self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
 | 
			
		||||
        self.assertIsInstance(self.task_pool._enough_room, Semaphore)
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
 | 
			
		||||
 | 
			
		||||
@@ -366,9 +365,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
 | 
			
		||||
 | 
			
		||||
    async def test_gather_and_close(self):
 | 
			
		||||
        mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
 | 
			
		||||
        mock_running_func = AsyncMock()
 | 
			
		||||
        mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
 | 
			
		||||
        self.task_pool._before_gathering = before_gather = [mock_before_gather()]
 | 
			
		||||
        self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
 | 
			
		||||
        self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
 | 
			
		||||
        self.task_pool._tasks_running = running = {789: mock_running_func()}
 | 
			
		||||
@@ -378,19 +376,16 @@ class BaseTaskPoolTestCase(CommonTestCase):
 | 
			
		||||
        self.assertDictEqual(ended, self.task_pool._tasks_ended)
 | 
			
		||||
        self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
 | 
			
		||||
        self.assertDictEqual(running, self.task_pool._tasks_running)
 | 
			
		||||
        self.assertListEqual(before_gather, self.task_pool._before_gathering)
 | 
			
		||||
        self.assertFalse(self.task_pool._closed)
 | 
			
		||||
 | 
			
		||||
        self.task_pool._locked = True
 | 
			
		||||
        self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
 | 
			
		||||
        mock_before_gather.assert_awaited_once_with()
 | 
			
		||||
        mock_ended_func.assert_awaited_once_with()
 | 
			
		||||
        mock_cancelled_func.assert_awaited_once_with()
 | 
			
		||||
        mock_running_func.assert_awaited_once_with()
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
 | 
			
		||||
        self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
 | 
			
		||||
        self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
 | 
			
		||||
        self.assertTrue(self.task_pool._closed)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -623,19 +618,32 @@ class TaskPoolTestCase(CommonTestCase):
 | 
			
		||||
            call(mock_func, bad, arg_stars=stars)
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
        mock_semaphore_cls.reset_mock()
 | 
			
		||||
        mock__get_map_end_callback.reset_mock()
 | 
			
		||||
        mock__start_task.reset_mock()
 | 
			
		||||
        mock_star_function.reset_mock()
 | 
			
		||||
 | 
			
		||||
        # With a CancelledError thrown while starting a task:
 | 
			
		||||
        mock_semaphore_cls.return_value = semaphore = Semaphore(1)
 | 
			
		||||
        mock_star_function.side_effect = CancelledError()
 | 
			
		||||
        mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize)
 | 
			
		||||
        self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
 | 
			
		||||
        self.assertFalse(semaphore.locked())
 | 
			
		||||
        mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
 | 
			
		||||
        mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
 | 
			
		||||
        mock__start_task.assert_not_called()
 | 
			
		||||
        mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
 | 
			
		||||
 | 
			
		||||
    @patch.object(pool, 'create_task')
 | 
			
		||||
    @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
 | 
			
		||||
    @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
 | 
			
		||||
    @patch.object(pool, 'join_queue', new_callable=MagicMock)
 | 
			
		||||
    @patch.object(pool, 'Queue')
 | 
			
		||||
    @patch.object(pool, 'TaskGroupRegister')
 | 
			
		||||
    @patch.object(pool.BaseTaskPool, '_check_start')
 | 
			
		||||
    async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
 | 
			
		||||
                        mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
 | 
			
		||||
                        mock_create_task: MagicMock):
 | 
			
		||||
                        mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock):
 | 
			
		||||
        mock_group_reg = set_up_mock_group_register(mock_reg_cls)
 | 
			
		||||
        mock_queue_cls.return_value = mock_q = MagicMock()
 | 
			
		||||
        mock_join_queue.return_value = fake_join = object()
 | 
			
		||||
        mock__queue_producer.return_value = fake_producer = object()
 | 
			
		||||
        mock__queue_consumer.return_value = fake_consumer = object()
 | 
			
		||||
        fake_task1, fake_task2 = object(), object()
 | 
			
		||||
@@ -669,8 +677,6 @@ class TaskPoolTestCase(CommonTestCase):
 | 
			
		||||
        self.task_pool._task_groups[group_name] = mock_group_reg
 | 
			
		||||
        mock_group_reg.__aenter__.assert_awaited_once_with()
 | 
			
		||||
        mock_queue_cls.assert_called_once_with(maxsize=group_size)
 | 
			
		||||
        mock_join_queue.assert_called_once_with(mock_q)
 | 
			
		||||
        self.assertListEqual([fake_join], self.task_pool._before_gathering)
 | 
			
		||||
        mock__queue_producer.assert_called_once()
 | 
			
		||||
        mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
 | 
			
		||||
        mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user