generated from daniil-berg/boilerplate-py
	drastically simplified meta-task internals
This commit is contained in:
		| @@ -31,15 +31,13 @@ import logging | ||||
| from asyncio.coroutines import iscoroutine, iscoroutinefunction | ||||
| from asyncio.exceptions import CancelledError | ||||
| from asyncio.locks import Semaphore | ||||
| from asyncio.queues import QueueEmpty | ||||
| from asyncio.tasks import Task, create_task, gather | ||||
| from contextlib import suppress | ||||
| from datetime import datetime | ||||
| from math import inf | ||||
| from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set, Union | ||||
| from typing import Any, Awaitable, Dict, Iterable, List, Set, Union | ||||
|  | ||||
| from . import exceptions | ||||
| from .queue_context import Queue | ||||
| from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT | ||||
| from .internals.group_register import TaskGroupRegister | ||||
| from .internals.helpers import execute_optional, star_function | ||||
| @@ -491,8 +489,6 @@ class TaskPool(BaseTaskPool): | ||||
|     Adding tasks blocks **only if** the pool is full at that moment. | ||||
|     """ | ||||
|  | ||||
|     _QUEUE_END_SENTINEL = object() | ||||
|  | ||||
|     def __init__(self, pool_size: int = inf, name: str = None) -> None: | ||||
|         super().__init__(pool_size=pool_size, name=name) | ||||
|         # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of | ||||
| @@ -714,34 +710,6 @@ class TaskPool(BaseTaskPool): | ||||
|         await task | ||||
|         return group_name | ||||
|  | ||||
|     @classmethod | ||||
|     async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None: | ||||
|         """ | ||||
|         Keeps the arguments queue from :meth:`_map` full as long as the iterator has elements. | ||||
|  | ||||
|         Intended to be run as a meta task of a specific group. | ||||
|  | ||||
|         Args: | ||||
|             arg_queue: The queue of function arguments to consume for starting a new task. | ||||
|             arg_iter: The iterator of function arguments to put into the queue. | ||||
|             group_name: Name of the task group associated with this producer. | ||||
|         """ | ||||
|         try: | ||||
|             for arg in arg_iter: | ||||
|                 await arg_queue.put(arg)  # This blocks as long as the queue is full. | ||||
|         except CancelledError: | ||||
|             # This means that no more tasks are supposed to be created from this `_map()` call; | ||||
|             # thus, we can immediately drain the entire queue and forget about the rest of the arguments. | ||||
|             log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name) | ||||
|             while True: | ||||
|                 try: | ||||
|                     arg_queue.get_nowait() | ||||
|                     arg_queue.item_processed() | ||||
|                 except QueueEmpty: | ||||
|                     return | ||||
|         finally: | ||||
|             await arg_queue.put(cls._QUEUE_END_SENTINEL) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB: | ||||
|         """Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`.""" | ||||
| @@ -750,23 +718,25 @@ class TaskPool(BaseTaskPool): | ||||
|             await execute_optional(actual_end_callback, args=(task_id,)) | ||||
|         return release_callback | ||||
|  | ||||
|     async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0, | ||||
|                               end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: | ||||
|     async def _arg_consumer(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, | ||||
|                             arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: | ||||
|         """ | ||||
|         Consumes arguments from the queue from :meth:`_map` and keeps a limited number of tasks working on them. | ||||
|         Consumes arguments from :meth:`_map` and keeps a limited number of tasks working on them. | ||||
|  | ||||
|         The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before | ||||
|         a new task can be started, and which must be released when one of these tasks ends. | ||||
|         The `group_size` acts as the limiting value of an internal semaphore, which must be acquired before a new task | ||||
|         can be started, and which must be released when one of these tasks ends. | ||||
|  | ||||
|         Intended to be run as a meta task of a specific group. | ||||
|  | ||||
|         Args: | ||||
|             arg_queue: | ||||
|                 The queue of function arguments to consume for starting a new task. | ||||
|             group_name: | ||||
|                 Name of the associated task group; passed into :meth:`_start_task`. | ||||
|             group_size: | ||||
|                 The maximum number new tasks spawned by this method to run concurrently. | ||||
|             func: | ||||
|                 The coroutine function to use for spawning the new tasks within the task pool. | ||||
|             arg_iter: | ||||
|                 The iterable of arguments; each element is to be passed into a `func` call when spawning a new task. | ||||
|             arg_stars (optional): | ||||
|                 Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2. | ||||
|             end_callback (optional): | ||||
| @@ -776,29 +746,27 @@ class TaskPool(BaseTaskPool): | ||||
|                 The callback that was specified to execute after cancellation of the task (and the next one). | ||||
|                 It is run with the task's ID as its only positional argument. | ||||
|         """ | ||||
|         map_semaphore = Semaphore(arg_queue.maxsize)  # value determined by `group_size` in :meth:`_map` | ||||
|         map_semaphore = Semaphore(group_size) | ||||
|         release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback) | ||||
|         while True: | ||||
|             # The following line blocks **only if** the number of running tasks spawned by this method has reached the | ||||
|             # specified maximum as determined in :meth:`_map`. | ||||
|         for next_arg in arg_iter: | ||||
|             # When the number of running tasks spawned by this method reaches the specified maximum, | ||||
|             # this next line will block, until one of them ends and releases the semaphore. | ||||
|             await map_semaphore.acquire() | ||||
|             # We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called. | ||||
|             async with arg_queue as next_arg: | ||||
|                 if next_arg is self._QUEUE_END_SENTINEL: | ||||
|                     # The :meth:`_queue_producer` either reached the last argument or was cancelled. | ||||
|                     return | ||||
|                 try: | ||||
|                     await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, | ||||
|                                            ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) | ||||
|                 except CancelledError: | ||||
|                     map_semaphore.release() | ||||
|                     return | ||||
|                 except Exception as e: | ||||
|                     # This means an exception occurred during task **creation**, meaning no task has been created. | ||||
|                     # It does not imply an error within the task itself. | ||||
|                     log.exception("%s occurred while trying to create task: %s(%s%s)", | ||||
|                                   str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) | ||||
|                     map_semaphore.release() | ||||
|             try: | ||||
|                 await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, | ||||
|                                        ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) | ||||
|             except CancelledError: | ||||
|                 # This means that no more tasks are supposed to be created from this `arg_iter`; | ||||
|                 # thus, we can forget about the rest of the arguments. | ||||
|                 log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name) | ||||
|                 map_semaphore.release() | ||||
|                 return | ||||
|             except Exception as e: | ||||
|                 # This means an exception occurred during task **creation**, meaning no task has been created. | ||||
|                 # It does not imply an error within the task itself. | ||||
|                 log.exception("%s occurred while trying to create task: %s(%s%s)", | ||||
|                               str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) | ||||
|                 map_semaphore.release() | ||||
|  | ||||
|     async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, | ||||
|                    end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: | ||||
| @@ -815,11 +783,9 @@ class TaskPool(BaseTaskPool): | ||||
|         of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running | ||||
|         concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course). | ||||
|  | ||||
|         This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`. | ||||
|         Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the | ||||
|         aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does | ||||
|         not mean that any task was started or that any number of tasks will start soon, as this is solely determined by | ||||
|         the :attr:`BaseTaskPool.pool_size` and the `group_size`. | ||||
|         Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just | ||||
|         because this method returns immediately, this does not mean that any task was started or that any number of | ||||
|         tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and the `group_size`. | ||||
|  | ||||
|         Args: | ||||
|             group_name: | ||||
| @@ -850,17 +816,9 @@ class TaskPool(BaseTaskPool): | ||||
|             raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") | ||||
|         self._task_groups[group_name] = group_reg = TaskGroupRegister() | ||||
|         async with group_reg: | ||||
|             # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the | ||||
|             # `_queue_producer()`; that way an argument | ||||
|             # TODO: Perhaps this can be simplified to just one meta-task with no need for a queue. | ||||
|             #       The limiting factor honoring the group size is already the semaphore in the queue consumer; | ||||
|             #       Try to write this without a producer, instead consuming the `arg_iter` directly. | ||||
|             arg_queue = Queue(maxsize=group_size) | ||||
|             meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) | ||||
|             # Start the producer and consumer meta tasks. | ||||
|             meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name))) | ||||
|             meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars, | ||||
|                                                             end_callback, cancel_callback))) | ||||
|             meta_tasks.add(create_task(self._arg_consumer(group_name, group_size, func, arg_iter, arg_stars, | ||||
|                                                           end_callback=end_callback, cancel_callback=cancel_callback))) | ||||
|  | ||||
|     async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None, | ||||
|                   end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: | ||||
|   | ||||
| @@ -20,7 +20,6 @@ Unittests for the `asyncio_taskpool.pool` module. | ||||
|  | ||||
| from asyncio.exceptions import CancelledError | ||||
| from asyncio.locks import Semaphore | ||||
| from asyncio.queues import QueueEmpty | ||||
| from datetime import datetime | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call | ||||
| @@ -555,28 +554,6 @@ class TaskPoolTestCase(CommonTestCase): | ||||
|         check_assertions(generated_name, output) | ||||
|         mock__generate_group_name.assert_called_once_with('apply', mock_func) | ||||
|  | ||||
|     @patch.object(pool, 'Queue') | ||||
|     async def test__queue_producer(self, mock_queue_cls: MagicMock): | ||||
|         mock_put = AsyncMock() | ||||
|         mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put) | ||||
|         item1, item2, item3 = FOO, 420, 69 | ||||
|         arg_iter = iter([item1, item2, item3]) | ||||
|         self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR)) | ||||
|         mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)]) | ||||
|         with self.assertRaises(StopIteration): | ||||
|             next(arg_iter) | ||||
|  | ||||
|         mock_put.reset_mock() | ||||
|  | ||||
|         mock_put.side_effect = [CancelledError, None] | ||||
|         arg_iter = iter([item1, item2, item3]) | ||||
|         mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty] | ||||
|         self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR)) | ||||
|         mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)]) | ||||
|         mock_queue.get_nowait.assert_has_calls([call(), call(), call()]) | ||||
|         mock_queue.item_processed.assert_has_calls([call(), call()]) | ||||
|         self.assertListEqual([item2, item3], list(arg_iter)) | ||||
|  | ||||
|     @patch.object(pool, 'execute_optional') | ||||
|     async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock): | ||||
|         semaphore, mock_end_cb = Semaphore(1), MagicMock() | ||||
| @@ -592,30 +569,28 @@ class TaskPoolTestCase(CommonTestCase): | ||||
|     @patch.object(pool, 'Semaphore') | ||||
|     async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock, | ||||
|                                    mock__start_task: AsyncMock, mock_star_function: MagicMock): | ||||
|         mock_semaphore_cls.return_value = semaphore = Semaphore(3) | ||||
|         n = 2 | ||||
|         mock_semaphore_cls.return_value = semaphore = Semaphore(n) | ||||
|         mock__get_map_end_callback.return_value = map_cb = MagicMock() | ||||
|         awaitable = 'totally an awaitable' | ||||
|         mock_star_function.side_effect = [awaitable, awaitable, Exception()] | ||||
|         mock_star_function.side_effect = [awaitable, Exception(), awaitable] | ||||
|         arg1, arg2, bad = 123456789, 'function argument', None | ||||
|         mock_q_maxsize = 3 | ||||
|         mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]), | ||||
|                            __aexit__=AsyncMock(), maxsize=mock_q_maxsize) | ||||
|         args = [arg1, bad, arg2] | ||||
|         group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) | ||||
|         # We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then | ||||
|         # acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3, | ||||
|         # at the end of the loop, we expect it be locked. | ||||
|         self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) | ||||
|         # We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then | ||||
|         # acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked. | ||||
|         self.assertTrue(semaphore.locked()) | ||||
|         mock_semaphore_cls.assert_called_once_with(mock_q_maxsize) | ||||
|         mock_semaphore_cls.assert_called_once_with(n) | ||||
|         mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) | ||||
|         mock__start_task.assert_has_awaits(2 * [ | ||||
|             call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) | ||||
|         ]) | ||||
|         mock_star_function.assert_has_calls([ | ||||
|             call(mock_func, arg1, arg_stars=stars), | ||||
|             call(mock_func, arg2, arg_stars=stars), | ||||
|             call(mock_func, bad, arg_stars=stars) | ||||
|             call(mock_func, bad, arg_stars=stars), | ||||
|             call(mock_func, arg2, arg_stars=stars) | ||||
|         ]) | ||||
|  | ||||
|         mock_semaphore_cls.reset_mock() | ||||
| @@ -626,61 +601,53 @@ class TaskPoolTestCase(CommonTestCase): | ||||
|         # With a CancelledError thrown while starting a task: | ||||
|         mock_semaphore_cls.return_value = semaphore = Semaphore(1) | ||||
|         mock_star_function.side_effect = CancelledError() | ||||
|         mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize) | ||||
|         self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) | ||||
|         self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) | ||||
|         self.assertFalse(semaphore.locked()) | ||||
|         mock_semaphore_cls.assert_called_once_with(mock_q_maxsize) | ||||
|         mock_semaphore_cls.assert_called_once_with(n) | ||||
|         mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) | ||||
|         mock__start_task.assert_not_called() | ||||
|         mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) | ||||
|  | ||||
|     @patch.object(pool, 'create_task') | ||||
|     @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock) | ||||
|     @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) | ||||
|     @patch.object(pool, 'Queue') | ||||
|     @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock) | ||||
|     @patch.object(pool, 'TaskGroupRegister') | ||||
|     @patch.object(pool.BaseTaskPool, '_check_start') | ||||
|     async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock, | ||||
|                         mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock): | ||||
|     async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock, | ||||
|                         mock_create_task: MagicMock): | ||||
|         mock_group_reg = set_up_mock_group_register(mock_reg_cls) | ||||
|         mock_queue_cls.return_value = mock_q = MagicMock() | ||||
|         mock__queue_producer.return_value = fake_producer = object() | ||||
|         mock__queue_consumer.return_value = fake_consumer = object() | ||||
|         fake_task1, fake_task2 = object(), object() | ||||
|         mock_create_task.side_effect = [fake_task1, fake_task2] | ||||
|         mock__arg_consumer.return_value = fake_consumer = object() | ||||
|         mock_create_task.return_value = fake_task = object() | ||||
|  | ||||
|         group_name, group_size = 'onetwothree', 0 | ||||
|         group_name, n = 'onetwothree', 0 | ||||
|         func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3 | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) | ||||
|             await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) | ||||
|         mock__check_start.assert_called_once_with(function=func) | ||||
|  | ||||
|         mock__check_start.reset_mock() | ||||
|  | ||||
|         group_size = 1234 | ||||
|         n = 1234 | ||||
|         self.task_pool._task_groups = {group_name: MagicMock()} | ||||
|  | ||||
|         with self.assertRaises(exceptions.InvalidGroupName): | ||||
|             await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) | ||||
|             await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) | ||||
|         mock__check_start.assert_called_once_with(function=func) | ||||
|  | ||||
|         mock__check_start.reset_mock() | ||||
|  | ||||
|         self.task_pool._task_groups.clear() | ||||
|         self.task_pool._before_gathering = [] | ||||
|  | ||||
|         self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)) | ||||
|         self.assertIsNone(await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)) | ||||
|         mock__check_start.assert_called_once_with(function=func) | ||||
|         mock_reg_cls.assert_called_once_with() | ||||
|         self.task_pool._task_groups[group_name] = mock_group_reg | ||||
|         mock_group_reg.__aenter__.assert_awaited_once_with() | ||||
|         mock_queue_cls.assert_called_once_with(maxsize=group_size) | ||||
|         mock__queue_producer.assert_called_once() | ||||
|         mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb) | ||||
|         mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)]) | ||||
|         self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name]) | ||||
|         mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars, | ||||
|                                                    end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock_create_task.assert_called_once_with(fake_consumer) | ||||
|         self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) | ||||
|         mock_group_reg.__aexit__.assert_awaited_once() | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_map') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user