From 54e5bfa8a07b10867c6b777c14af1c0d001c0085 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sat, 26 Mar 2022 10:49:32 +0100 Subject: [PATCH] drastically simplified meta-task internals --- src/asyncio_taskpool/pool.py | 110 +++++++++++------------------------ tests/test_pool.py | 85 +++++++++------------------ 2 files changed, 60 insertions(+), 135 deletions(-) diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 34df593..5f7a122 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -31,15 +31,13 @@ import logging from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.exceptions import CancelledError from asyncio.locks import Semaphore -from asyncio.queues import QueueEmpty from asyncio.tasks import Task, create_task, gather from contextlib import suppress from datetime import datetime from math import inf -from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set, Union +from typing import Any, Awaitable, Dict, Iterable, List, Set, Union from . import exceptions -from .queue_context import Queue from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT from .internals.group_register import TaskGroupRegister from .internals.helpers import execute_optional, star_function @@ -491,8 +489,6 @@ class TaskPool(BaseTaskPool): Adding tasks blocks **only if** the pool is full at that moment. """ - _QUEUE_END_SENTINEL = object() - def __init__(self, pool_size: int = inf, name: str = None) -> None: super().__init__(pool_size=pool_size, name=name) # In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of @@ -714,34 +710,6 @@ class TaskPool(BaseTaskPool): await task return group_name - @classmethod - async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None: - """ - Keeps the arguments queue from :meth:`_map` full as long as the iterator has elements. - - Intended to be run as a meta task of a specific group. - - Args: - arg_queue: The queue of function arguments to consume for starting a new task. - arg_iter: The iterator of function arguments to put into the queue. - group_name: Name of the task group associated with this producer. - """ - try: - for arg in arg_iter: - await arg_queue.put(arg) # This blocks as long as the queue is full. - except CancelledError: - # This means that no more tasks are supposed to be created from this `_map()` call; - # thus, we can immediately drain the entire queue and forget about the rest of the arguments. - log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name) - while True: - try: - arg_queue.get_nowait() - arg_queue.item_processed() - except QueueEmpty: - return - finally: - await arg_queue.put(cls._QUEUE_END_SENTINEL) - @staticmethod def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB: """Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`.""" @@ -750,23 +718,25 @@ class TaskPool(BaseTaskPool): await execute_optional(actual_end_callback, args=(task_id,)) return release_callback - async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0, - end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: + async def _arg_consumer(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, + arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: """ - Consumes arguments from the queue from :meth:`_map` and keeps a limited number of tasks working on them. + Consumes arguments from :meth:`_map` and keeps a limited number of tasks working on them. - The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before - a new task can be started, and which must be released when one of these tasks ends. + The `group_size` acts as the limiting value of an internal semaphore, which must be acquired before a new task + can be started, and which must be released when one of these tasks ends. Intended to be run as a meta task of a specific group. Args: - arg_queue: - The queue of function arguments to consume for starting a new task. group_name: Name of the associated task group; passed into :meth:`_start_task`. + group_size: + The maximum number new tasks spawned by this method to run concurrently. func: The coroutine function to use for spawning the new tasks within the task pool. + arg_iter: + The iterable of arguments; each element is to be passed into a `func` call when spawning a new task. arg_stars (optional): Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2. end_callback (optional): @@ -776,29 +746,27 @@ class TaskPool(BaseTaskPool): The callback that was specified to execute after cancellation of the task (and the next one). It is run with the task's ID as its only positional argument. """ - map_semaphore = Semaphore(arg_queue.maxsize) # value determined by `group_size` in :meth:`_map` + map_semaphore = Semaphore(group_size) release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback) - while True: - # The following line blocks **only if** the number of running tasks spawned by this method has reached the - # specified maximum as determined in :meth:`_map`. + for next_arg in arg_iter: + # When the number of running tasks spawned by this method reaches the specified maximum, + # this next line will block, until one of them ends and releases the semaphore. await map_semaphore.acquire() - # We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called. - async with arg_queue as next_arg: - if next_arg is self._QUEUE_END_SENTINEL: - # The :meth:`_queue_producer` either reached the last argument or was cancelled. - return - try: - await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, - ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) - except CancelledError: - map_semaphore.release() - return - except Exception as e: - # This means an exception occurred during task **creation**, meaning no task has been created. - # It does not imply an error within the task itself. - log.exception("%s occurred while trying to create task: %s(%s%s)", - str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) - map_semaphore.release() + try: + await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, + ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) + except CancelledError: + # This means that no more tasks are supposed to be created from this `arg_iter`; + # thus, we can forget about the rest of the arguments. + log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name) + map_semaphore.release() + return + except Exception as e: + # This means an exception occurred during task **creation**, meaning no task has been created. + # It does not imply an error within the task itself. + log.exception("%s occurred while trying to create task: %s(%s%s)", + str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) + map_semaphore.release() async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: @@ -815,11 +783,9 @@ class TaskPool(BaseTaskPool): of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course). - This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`. - Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the - aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does - not mean that any task was started or that any number of tasks will start soon, as this is solely determined by - the :attr:`BaseTaskPool.pool_size` and the `group_size`. + Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just + because this method returns immediately, this does not mean that any task was started or that any number of + tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and the `group_size`. Args: group_name: @@ -850,17 +816,9 @@ class TaskPool(BaseTaskPool): raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") self._task_groups[group_name] = group_reg = TaskGroupRegister() async with group_reg: - # Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the - # `_queue_producer()`; that way an argument - # TODO: Perhaps this can be simplified to just one meta-task with no need for a queue. - # The limiting factor honoring the group size is already the semaphore in the queue consumer; - # Try to write this without a producer, instead consuming the `arg_iter` directly. - arg_queue = Queue(maxsize=group_size) meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) - # Start the producer and consumer meta tasks. - meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name))) - meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars, - end_callback, cancel_callback))) + meta_tasks.add(create_task(self._arg_consumer(group_name, group_size, func, arg_iter, arg_stars, + end_callback=end_callback, cancel_callback=cancel_callback))) async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: diff --git a/tests/test_pool.py b/tests/test_pool.py index d2a82b5..1e0b94c 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -20,7 +20,6 @@ Unittests for the `asyncio_taskpool.pool` module. from asyncio.exceptions import CancelledError from asyncio.locks import Semaphore -from asyncio.queues import QueueEmpty from datetime import datetime from unittest import IsolatedAsyncioTestCase from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call @@ -555,28 +554,6 @@ class TaskPoolTestCase(CommonTestCase): check_assertions(generated_name, output) mock__generate_group_name.assert_called_once_with('apply', mock_func) - @patch.object(pool, 'Queue') - async def test__queue_producer(self, mock_queue_cls: MagicMock): - mock_put = AsyncMock() - mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put) - item1, item2, item3 = FOO, 420, 69 - arg_iter = iter([item1, item2, item3]) - self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR)) - mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)]) - with self.assertRaises(StopIteration): - next(arg_iter) - - mock_put.reset_mock() - - mock_put.side_effect = [CancelledError, None] - arg_iter = iter([item1, item2, item3]) - mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty] - self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR)) - mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)]) - mock_queue.get_nowait.assert_has_calls([call(), call(), call()]) - mock_queue.item_processed.assert_has_calls([call(), call()]) - self.assertListEqual([item2, item3], list(arg_iter)) - @patch.object(pool, 'execute_optional') async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock): semaphore, mock_end_cb = Semaphore(1), MagicMock() @@ -592,30 +569,28 @@ class TaskPoolTestCase(CommonTestCase): @patch.object(pool, 'Semaphore') async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock, mock__start_task: AsyncMock, mock_star_function: MagicMock): - mock_semaphore_cls.return_value = semaphore = Semaphore(3) + n = 2 + mock_semaphore_cls.return_value = semaphore = Semaphore(n) mock__get_map_end_callback.return_value = map_cb = MagicMock() awaitable = 'totally an awaitable' - mock_star_function.side_effect = [awaitable, awaitable, Exception()] + mock_star_function.side_effect = [awaitable, Exception(), awaitable] arg1, arg2, bad = 123456789, 'function argument', None - mock_q_maxsize = 3 - mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]), - __aexit__=AsyncMock(), maxsize=mock_q_maxsize) + args = [arg1, bad, arg2] group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) - # We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then - # acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3, - # at the end of the loop, we expect it be locked. + self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) + # We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then + # acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked. self.assertTrue(semaphore.locked()) - mock_semaphore_cls.assert_called_once_with(mock_q_maxsize) + mock_semaphore_cls.assert_called_once_with(n) mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__start_task.assert_has_awaits(2 * [ call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) ]) mock_star_function.assert_has_calls([ call(mock_func, arg1, arg_stars=stars), - call(mock_func, arg2, arg_stars=stars), - call(mock_func, bad, arg_stars=stars) + call(mock_func, bad, arg_stars=stars), + call(mock_func, arg2, arg_stars=stars) ]) mock_semaphore_cls.reset_mock() @@ -626,61 +601,53 @@ class TaskPoolTestCase(CommonTestCase): # With a CancelledError thrown while starting a task: mock_semaphore_cls.return_value = semaphore = Semaphore(1) mock_star_function.side_effect = CancelledError() - mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize) - self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) + self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) self.assertFalse(semaphore.locked()) - mock_semaphore_cls.assert_called_once_with(mock_q_maxsize) + mock_semaphore_cls.assert_called_once_with(n) mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__start_task.assert_not_called() mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) @patch.object(pool, 'create_task') - @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock) - @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) - @patch.object(pool, 'Queue') + @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock) @patch.object(pool, 'TaskGroupRegister') @patch.object(pool.BaseTaskPool, '_check_start') - async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock, - mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock): + async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock, + mock_create_task: MagicMock): mock_group_reg = set_up_mock_group_register(mock_reg_cls) - mock_queue_cls.return_value = mock_q = MagicMock() - mock__queue_producer.return_value = fake_producer = object() - mock__queue_consumer.return_value = fake_consumer = object() - fake_task1, fake_task2 = object(), object() - mock_create_task.side_effect = [fake_task1, fake_task2] + mock__arg_consumer.return_value = fake_consumer = object() + mock_create_task.return_value = fake_task = object() - group_name, group_size = 'onetwothree', 0 + group_name, n = 'onetwothree', 0 func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3 end_cb, cancel_cb = MagicMock(), MagicMock() with self.assertRaises(ValueError): - await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) + await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) mock__check_start.assert_called_once_with(function=func) mock__check_start.reset_mock() - group_size = 1234 + n = 1234 self.task_pool._task_groups = {group_name: MagicMock()} with self.assertRaises(exceptions.InvalidGroupName): - await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) + await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) mock__check_start.assert_called_once_with(function=func) mock__check_start.reset_mock() self.task_pool._task_groups.clear() - self.task_pool._before_gathering = [] - self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)) + self.assertIsNone(await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)) mock__check_start.assert_called_once_with(function=func) mock_reg_cls.assert_called_once_with() self.task_pool._task_groups[group_name] = mock_group_reg mock_group_reg.__aenter__.assert_awaited_once_with() - mock_queue_cls.assert_called_once_with(maxsize=group_size) - mock__queue_producer.assert_called_once() - mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb) - mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)]) - self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name]) + mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars, + end_callback=end_cb, cancel_callback=cancel_cb) + mock_create_task.assert_called_once_with(fake_consumer) + self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) mock_group_reg.__aexit__.assert_awaited_once() @patch.object(pool.TaskPool, '_map')