diff --git a/setup.cfg b/setup.cfg index 934e13a..c75772a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.1.4 +version = 0.1.5 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index de6a589..a89876c 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -498,38 +498,47 @@ class TaskPool(BaseTaskPool): await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback) await execute_optional(end_callback, args=(task_id,)) - def _fill_args_queue(self, q: Queue, args_iter: ArgsT, num_tasks: int) -> None: + def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue: """ Helper function for `_map()`. - Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to the arguments queue `q`. + Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to a new `asyncio.Queue`. + The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned. - If the iterable contains less than `num_tasks` elements, nothing else happens. - Otherwise the `_queue_producer` is started with the arguments queue and and iterator of the remaining arguments. + If the iterable contains less than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer` + is started as a separate task with the arguments queue and and iterator of the remaining arguments. Args: - q: - The (empty) new `asyncio.Queue` to hold the function arguments passed as `args_iter`. args_iter: The iterable of function arguments passed into `_map()` to use for creating the new tasks. num_tasks: The maximum number of the new tasks to run concurrently that was passed into `_map()`. + + Returns: + The newly created and filled arguments queue for spawning new tasks. """ + # Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run + # concurrently because the size of the queue is what will determine the number of immediately started tasks in + # the `_map()` method and each of those will only ever start (at most) one other task upon ending. + args_queue = Queue(maxsize=num_tasks) + self._before_gathering.append(join_queue(args_queue)) args_iter = iter(args_iter) try: # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of # tasks, which will be at most `num_tasks` (meaning the queue will be full). for i in range(num_tasks): - q.put_nowait(next(args_iter)) + args_queue.put_nowait(next(args_iter)) except StopIteration: # If we get here, this means that the number of elements in the arguments iterator was less than the - # specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue. + # specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue. # The `_queue_producer` won't be necessary, since we already put all the elements in the queue. - return - # There may be more elements in the arguments iterator, so we need the `_queue_producer`. - # It will have exclusive access to the `args_iter` from now on. - # If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting - # the next item in it. - create_task(self._queue_producer(q, args_iter)) + pass + else: + # There may be more elements in the arguments iterator, so we need the `_queue_producer`. + # It will have exclusive access to the `args_iter` from now on. + # Since the queue is full already, it will wait until one of the tasks in the first batch ends, + # before putting the next item in it. + create_task(self._queue_producer(args_queue, args_iter)) + return args_queue async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: @@ -542,7 +551,6 @@ class TaskPool(BaseTaskPool): This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable. - The queue's `join()` method is added to the pool's `_before_gathering` list. Args: func: @@ -565,9 +573,7 @@ class TaskPool(BaseTaskPool): """ if not self.is_open: raise exceptions.PoolIsClosed("Cannot start new tasks") - args_queue = Queue(maxsize=num_tasks) - self._before_gathering.append(join_queue(args_queue)) - self._fill_args_queue(args_queue, args_iter, num_tasks) + args_queue = self._set_up_args_queue(args_iter, num_tasks) for _ in range(args_queue.qsize()): # This is where blocking can occur, if the pool is full. await self._queue_consumer(args_queue, func, diff --git a/tests/test_pool.py b/tests/test_pool.py index 110f74f..6e7e3fe 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,7 +1,9 @@ import asyncio from asyncio.exceptions import CancelledError +from asyncio.queues import Queue from unittest import IsolatedAsyncioTestCase from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call +from typing import Type from asyncio_taskpool import pool, exceptions @@ -14,7 +16,12 @@ class TestException(Exception): pass -class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): +class CommonTestCase(IsolatedAsyncioTestCase): + TEST_CLASS: Type[pool.BaseTaskPool] = pool.BaseTaskPool + TEST_POOL_SIZE: int = 420 + TEST_POOL_NAME: str = 'test123' + + task_pool: pool.BaseTaskPool log_lvl: int @classmethod @@ -26,35 +33,38 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): def tearDownClass(cls) -> None: pool.log.setLevel(cls.log_lvl) - def setUp(self) -> None: - self._pools = getattr(pool.BaseTaskPool, '_pools') + def get_task_pool_init_params(self) -> dict: + return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME} - # These three methods are called during initialization, so we mock them by default during setup - self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool') - self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock) - self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__') + def setUp(self) -> None: + self._pools = self.TEST_CLASS._pools + # These three methods are called during initialization, so we mock them by default during setup: + self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool') + self.pool_size_patcher = patch.object(self.TEST_CLASS, 'pool_size', new_callable=PropertyMock) + self.dunder_str_patcher = patch.object(self.TEST_CLASS, '__str__') self.mock__add_pool = self._add_pool_patcher.start() self.mock_pool_size = self.pool_size_patcher.start() - self.mock___str__ = self.__str___patcher.start() + self.mock___str__ = self.dunder_str_patcher.start() self.mock__add_pool.return_value = self.mock_idx = 123 self.mock___str__.return_value = self.mock_str = 'foobar' - # Test pool parameters: - self.test_pool_size, self.test_pool_name = 420, 'test123' - self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name) + self.task_pool = self.TEST_CLASS(**self.get_task_pool_init_params()) def tearDown(self) -> None: - setattr(pool.TaskPool, '_pools', self._pools) + self.TEST_CLASS._pools.clear() self._add_pool_patcher.stop() self.pool_size_patcher.stop() - self.__str___patcher.stop() + self.dunder_str_patcher.stop() + + +class BaseTaskPoolTestCase(CommonTestCase): def test__add_pool(self): self.assertListEqual(EMPTY_LIST, self._pools) self._add_pool_patcher.stop() - output = pool.TaskPool._add_pool(self.task_pool) + output = pool.BaseTaskPool._add_pool(self.task_pool) self.assertEqual(0, output) - self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) + self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools) def test_init(self): self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) @@ -66,26 +76,26 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): self.assertEqual(0, self.task_pool._num_cancelled) self.assertEqual(0, self.task_pool._num_ended) self.assertEqual(self.mock_idx, self.task_pool._idx) - self.assertEqual(self.test_pool_name, self.task_pool._name) + self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event) self.assertFalse(self.task_pool._interrupt_flag.is_set()) self.mock__add_pool.assert_called_once_with(self.task_pool) - self.mock_pool_size.assert_called_once_with(self.test_pool_size) + self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE) self.mock___str__.assert_called_once_with() def test___str__(self): - self.__str___patcher.stop() - expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}' + self.dunder_str_patcher.stop() + expected_str = f'{pool.BaseTaskPool.__name__}-{self.TEST_POOL_NAME}' self.assertEqual(expected_str, str(self.task_pool)) - setattr(self.task_pool, '_name', None) + self.task_pool._name = None expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' self.assertEqual(expected_str, str(self.task_pool)) def test_pool_size(self): self.pool_size_patcher.stop() - self.task_pool._pool_size = self.test_pool_size - self.assertEqual(self.test_pool_size, self.task_pool.pool_size) + self.task_pool._pool_size = self.TEST_POOL_SIZE + self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size) with self.assertRaises(ValueError): self.task_pool.pool_size = -1 @@ -377,3 +387,190 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): self.task_pool._cancelled = {456: mock_cancelled_func()} self.task_pool._running = {789: mock_running_func()} check_assertions(await self.task_pool.gather(return_exceptions=True)) + + +class TaskPoolTestCase(CommonTestCase): + TEST_CLASS = pool.TaskPool + task_pool: pool.TaskPool + + @patch.object(pool.TaskPool, '_start_task') + async def test__apply_one(self, mock__start_task: AsyncMock): + mock__start_task.return_value = expected_output = 12345 + mock_awaitable = MagicMock() + mock_func = MagicMock(return_value=mock_awaitable) + args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} + end_cb, cancel_cb = MagicMock(), MagicMock() + output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb) + self.assertEqual(expected_output, output) + mock_func.assert_called_once_with(*args, **kwargs) + mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) + + mock_func.reset_mock() + mock__start_task.reset_mock() + + output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb) + self.assertEqual(expected_output, output) + mock_func.assert_called_once_with(*args) + mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) + + @patch.object(pool.TaskPool, '_apply_one') + async def test_apply(self, mock__apply_one: AsyncMock): + mock__apply_one.return_value = mock_id = 67890 + mock_func, num = MagicMock(), 3 + args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} + end_cb, cancel_cb = MagicMock(), MagicMock() + expected_output = num * [mock_id] + output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb) + self.assertEqual(expected_output, output) + mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)]) + + async def test__queue_producer(self): + mock_put = AsyncMock() + mock_q = MagicMock(put=mock_put) + args = (FOO, BAR, 123) + assert not self.task_pool._interrupt_flag.is_set() + self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) + mock_put.assert_has_awaits([call(arg) for arg in args]) + mock_put.reset_mock() + self.task_pool._interrupt_flag.set() + self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) + mock_put.assert_not_awaited() + + @patch.object(pool, 'partial') + @patch.object(pool, 'star_function') + @patch.object(pool.TaskPool, '_start_task') + async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock, + mock_partial: MagicMock): + mock_partial.return_value = queue_callback = 'not really' + mock_star_function.return_value = awaitable = 'totally an awaitable' + q, arg = Queue(), 420.69 + q.put_nowait(arg) + mock_func, stars = MagicMock(), 3 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb)) + self.assertTrue(q.empty()) + mock__start_task.assert_awaited_once_with(awaitable, ignore_closed=True, + end_callback=queue_callback, cancel_callback=cancel_cb) + mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars) + mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool, + q=q, func=mock_func, arg_stars=stars, + end_callback=end_cb, cancel_callback=cancel_cb) + mock__start_task.reset_mock() + mock_star_function.reset_mock() + mock_partial.reset_mock() + + self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb)) + self.assertTrue(q.empty()) + mock__start_task.assert_not_awaited() + mock_star_function.assert_not_called() + mock_partial.assert_not_called() + + @patch.object(pool, 'execute_optional') + @patch.object(pool.TaskPool, '_queue_consumer') + async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock): + task_id, mock_q = 420, MagicMock() + mock_func, stars = MagicMock(), 3 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_func, stars, end_cb, cancel_cb)) + mock__queue_consumer.assert_awaited_once_with(mock_q, mock_func, stars, + end_callback=end_cb, cancel_callback=cancel_cb) + mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,)) + + @patch.object(pool, 'iter') + @patch.object(pool, 'create_task') + @patch.object(pool, 'join_queue', new_callable=MagicMock) + @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) + async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock, + mock_create_task: MagicMock, mock_iter: MagicMock): + args, num_tasks = (FOO, BAR, 1, 2, 3), 2 + mock_join_queue.return_value = mock_join = 'awaitable' + mock_iter.return_value = args_iter = iter(args) + mock__queue_producer.return_value = mock_producer_coro = 'very awaitable' + output_q = self.task_pool._set_up_args_queue(args, num_tasks) + self.assertIsInstance(output_q, Queue) + self.assertEqual(num_tasks, output_q.qsize()) + for arg in args[:num_tasks]: + self.assertEqual(arg, output_q.get_nowait()) + self.assertTrue(output_q.empty()) + for arg in args[num_tasks:]: + self.assertEqual(arg, next(args_iter)) + with self.assertRaises(StopIteration): + next(args_iter) + self.assertListEqual([mock_join], self.task_pool._before_gathering) + mock_join_queue.assert_called_once_with(output_q) + mock__queue_producer.assert_called_once_with(output_q, args_iter) + mock_create_task.assert_called_once_with(mock_producer_coro) + + self.task_pool._before_gathering.clear() + mock_join_queue.reset_mock() + mock__queue_producer.reset_mock() + mock_create_task.reset_mock() + + num_tasks = 6 + mock_iter.return_value = args_iter = iter(args) + output_q = self.task_pool._set_up_args_queue(args, num_tasks) + self.assertIsInstance(output_q, Queue) + self.assertEqual(len(args), output_q.qsize()) + for arg in args: + self.assertEqual(arg, output_q.get_nowait()) + self.assertTrue(output_q.empty()) + with self.assertRaises(StopIteration): + next(args_iter) + self.assertListEqual([mock_join], self.task_pool._before_gathering) + mock_join_queue.assert_called_once_with(output_q) + mock__queue_producer.assert_not_called() + mock_create_task.assert_not_called() + + @patch.object(pool.TaskPool, '_queue_consumer') + @patch.object(pool.TaskPool, '_set_up_args_queue') + @patch.object(pool.TaskPool, 'is_open', new_callable=PropertyMock) + async def test__map(self, mock_is_open: MagicMock, mock__set_up_args_queue: MagicMock, + mock__queue_consumer: AsyncMock): + qsize = 4 + mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize)) + + mock_func, stars = MagicMock(), 3 + args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2 + end_cb, cancel_cb = MagicMock(), MagicMock() + + mock_is_open.return_value = False + with self.assertRaises(exceptions.PoolIsClosed): + await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb) + mock_is_open.assert_called_once_with() + mock__set_up_args_queue.assert_not_called() + mock__queue_consumer.assert_not_awaited() + + mock_is_open.reset_mock() + + mock_is_open.return_value = True + self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)) + mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks) + mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_func, arg_stars=stars, + end_callback=end_cb, cancel_callback=cancel_cb)]) + + @patch.object(pool.TaskPool, '_map') + async def test_map(self, mock__map: AsyncMock): + mock_func = MagicMock() + arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb)) + mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks, + end_callback=end_cb, cancel_callback=cancel_cb) + + @patch.object(pool.TaskPool, '_map') + async def test_starmap(self, mock__map: AsyncMock): + mock_func = MagicMock() + args_iter, num_tasks = ([FOO], [BAR]), 2 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb)) + mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks, + end_callback=end_cb, cancel_callback=cancel_cb) + + @patch.object(pool.TaskPool, '_map') + async def test_doublestarmap(self, mock__map: AsyncMock): + mock_func = MagicMock() + kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2 + end_cb, cancel_cb = MagicMock(), MagicMock() + self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb)) + mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, + end_callback=end_cb, cancel_callback=cancel_cb)