added full test coverage of TaskPool and further improved its _map() method

This commit is contained in:
Daniil Fajnberg 2022-02-08 13:29:57 +01:00
parent d48b20818f
commit 63aab1a8f6
3 changed files with 244 additions and 41 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.1.4 version = 0.1.5
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -498,38 +498,47 @@ class TaskPool(BaseTaskPool):
await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback) await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
await execute_optional(end_callback, args=(task_id,)) await execute_optional(end_callback, args=(task_id,))
def _fill_args_queue(self, q: Queue, args_iter: ArgsT, num_tasks: int) -> None: def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
""" """
Helper function for `_map()`. Helper function for `_map()`.
Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to the arguments queue `q`. Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to a new `asyncio.Queue`.
The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.
If the iterable contains less than `num_tasks` elements, nothing else happens. If the iterable contains less than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer`
Otherwise the `_queue_producer` is started with the arguments queue and and iterator of the remaining arguments. is started as a separate task with the arguments queue and and iterator of the remaining arguments.
Args: Args:
q:
The (empty) new `asyncio.Queue` to hold the function arguments passed as `args_iter`.
args_iter: args_iter:
The iterable of function arguments passed into `_map()` to use for creating the new tasks. The iterable of function arguments passed into `_map()` to use for creating the new tasks.
num_tasks: num_tasks:
The maximum number of the new tasks to run concurrently that was passed into `_map()`. The maximum number of the new tasks to run concurrently that was passed into `_map()`.
Returns:
The newly created and filled arguments queue for spawning new tasks.
""" """
# Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run
# concurrently because the size of the queue is what will determine the number of immediately started tasks in
# the `_map()` method and each of those will only ever start (at most) one other task upon ending.
args_queue = Queue(maxsize=num_tasks)
self._before_gathering.append(join_queue(args_queue))
args_iter = iter(args_iter) args_iter = iter(args_iter)
try: try:
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
# tasks, which will be at most `num_tasks` (meaning the queue will be full). # tasks, which will be at most `num_tasks` (meaning the queue will be full).
for i in range(num_tasks): for i in range(num_tasks):
q.put_nowait(next(args_iter)) args_queue.put_nowait(next(args_iter))
except StopIteration: except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the # If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue. # specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue. # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
return pass
else:
# There may be more elements in the arguments iterator, so we need the `_queue_producer`. # There may be more elements in the arguments iterator, so we need the `_queue_producer`.
# It will have exclusive access to the `args_iter` from now on. # It will have exclusive access to the `args_iter` from now on.
# If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting # Since the queue is full already, it will wait until one of the tasks in the first batch ends,
# the next item in it. # before putting the next item in it.
create_task(self._queue_producer(q, args_iter)) create_task(self._queue_producer(args_queue, args_iter))
return args_queue
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
@ -542,7 +551,6 @@ class TaskPool(BaseTaskPool):
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable. It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
The queue's `join()` method is added to the pool's `_before_gathering` list.
Args: Args:
func: func:
@ -565,9 +573,7 @@ class TaskPool(BaseTaskPool):
""" """
if not self.is_open: if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsClosed("Cannot start new tasks")
args_queue = Queue(maxsize=num_tasks) args_queue = self._set_up_args_queue(args_iter, num_tasks)
self._before_gathering.append(join_queue(args_queue))
self._fill_args_queue(args_queue, args_iter, num_tasks)
for _ in range(args_queue.qsize()): for _ in range(args_queue.qsize()):
# This is where blocking can occur, if the pool is full. # This is where blocking can occur, if the pool is full.
await self._queue_consumer(args_queue, func, await self._queue_consumer(args_queue, func,

View File

@ -1,7 +1,9 @@
import asyncio import asyncio
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.queues import Queue
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type
from asyncio_taskpool import pool, exceptions from asyncio_taskpool import pool, exceptions
@ -14,7 +16,12 @@ class TestException(Exception):
pass pass
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): class CommonTestCase(IsolatedAsyncioTestCase):
TEST_CLASS: Type[pool.BaseTaskPool] = pool.BaseTaskPool
TEST_POOL_SIZE: int = 420
TEST_POOL_NAME: str = 'test123'
task_pool: pool.BaseTaskPool
log_lvl: int log_lvl: int
@classmethod @classmethod
@ -26,35 +33,38 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl) pool.log.setLevel(cls.log_lvl)
def setUp(self) -> None: def get_task_pool_init_params(self) -> dict:
self._pools = getattr(pool.BaseTaskPool, '_pools') return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
# These three methods are called during initialization, so we mock them by default during setup def setUp(self) -> None:
self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool') self._pools = self.TEST_CLASS._pools
self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock) # These three methods are called during initialization, so we mock them by default during setup:
self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__') self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
self.pool_size_patcher = patch.object(self.TEST_CLASS, 'pool_size', new_callable=PropertyMock)
self.dunder_str_patcher = patch.object(self.TEST_CLASS, '__str__')
self.mock__add_pool = self._add_pool_patcher.start() self.mock__add_pool = self._add_pool_patcher.start()
self.mock_pool_size = self.pool_size_patcher.start() self.mock_pool_size = self.pool_size_patcher.start()
self.mock___str__ = self.__str___patcher.start() self.mock___str__ = self.dunder_str_patcher.start()
self.mock__add_pool.return_value = self.mock_idx = 123 self.mock__add_pool.return_value = self.mock_idx = 123
self.mock___str__.return_value = self.mock_str = 'foobar' self.mock___str__.return_value = self.mock_str = 'foobar'
# Test pool parameters: self.task_pool = self.TEST_CLASS(**self.get_task_pool_init_params())
self.test_pool_size, self.test_pool_name = 420, 'test123'
self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
def tearDown(self) -> None: def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools) self.TEST_CLASS._pools.clear()
self._add_pool_patcher.stop() self._add_pool_patcher.stop()
self.pool_size_patcher.stop() self.pool_size_patcher.stop()
self.__str___patcher.stop() self.dunder_str_patcher.stop()
class BaseTaskPoolTestCase(CommonTestCase):
def test__add_pool(self): def test__add_pool(self):
self.assertListEqual(EMPTY_LIST, self._pools) self.assertListEqual(EMPTY_LIST, self._pools)
self._add_pool_patcher.stop() self._add_pool_patcher.stop()
output = pool.TaskPool._add_pool(self.task_pool) output = pool.BaseTaskPool._add_pool(self.task_pool)
self.assertEqual(0, output) self.assertEqual(0, output)
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self): def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
@ -66,26 +76,26 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(0, self.task_pool._num_cancelled) self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended) self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set()) self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.test_pool_size) self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with() self.mock___str__.assert_called_once_with()
def test___str__(self): def test___str__(self):
self.__str___patcher.stop() self.dunder_str_patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.TEST_POOL_NAME}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
setattr(self.task_pool, '_name', None) self.task_pool._name = None
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
def test_pool_size(self): def test_pool_size(self):
self.pool_size_patcher.stop() self.pool_size_patcher.stop()
self.task_pool._pool_size = self.test_pool_size self.task_pool._pool_size = self.TEST_POOL_SIZE
self.assertEqual(self.test_pool_size, self.task_pool.pool_size) self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.task_pool.pool_size = -1 self.task_pool.pool_size = -1
@ -377,3 +387,190 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool._cancelled = {456: mock_cancelled_func()} self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()} self.task_pool._running = {789: mock_running_func()}
check_assertions(await self.task_pool.gather(return_exceptions=True)) check_assertions(await self.task_pool.gather(return_exceptions=True))
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 12345
mock_awaitable = MagicMock()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
mock_func.reset_mock()
mock__start_task.reset_mock()
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_apply_one')
async def test_apply(self, mock__apply_one: AsyncMock):
mock__apply_one.return_value = mock_id = 67890
mock_func, num = MagicMock(), 3
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
expected_output = num * [mock_id]
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
async def test__queue_producer(self):
mock_put = AsyncMock()
mock_q = MagicMock(put=mock_put)
args = (FOO, BAR, 123)
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_has_awaits([call(arg) for arg in args])
mock_put.reset_mock()
self.task_pool._interrupt_flag.set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_not_awaited()
@patch.object(pool, 'partial')
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
mock_partial: MagicMock):
mock_partial.return_value = queue_callback = 'not really'
mock_star_function.return_value = awaitable = 'totally an awaitable'
q, arg = Queue(), 420.69
q.put_nowait(arg)
mock_func, stars = MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_closed=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
q=q, func=mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
mock_partial.assert_not_called()
@patch.object(pool, 'execute_optional')
@patch.object(pool.TaskPool, '_queue_consumer')
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
task_id, mock_q = 420, MagicMock()
mock_func, stars = MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_func, stars, end_cb, cancel_cb))
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
@patch.object(pool, 'iter')
@patch.object(pool, 'create_task')
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
mock_create_task: MagicMock, mock_iter: MagicMock):
args, num_tasks = (FOO, BAR, 1, 2, 3), 2
mock_join_queue.return_value = mock_join = 'awaitable'
mock_iter.return_value = args_iter = iter(args)
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(num_tasks, output_q.qsize())
for arg in args[:num_tasks]:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
for arg in args[num_tasks:]:
self.assertEqual(arg, next(args_iter))
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_called_once_with(output_q, args_iter)
mock_create_task.assert_called_once_with(mock_producer_coro)
self.task_pool._before_gathering.clear()
mock_join_queue.reset_mock()
mock__queue_producer.reset_mock()
mock_create_task.reset_mock()
num_tasks = 6
mock_iter.return_value = args_iter = iter(args)
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(len(args), output_q.qsize())
for arg in args:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_not_called()
mock_create_task.assert_not_called()
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
@patch.object(pool.TaskPool, 'is_open', new_callable=PropertyMock)
async def test__map(self, mock_is_open: MagicMock, mock__set_up_args_queue: MagicMock,
mock__queue_consumer: AsyncMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_func, stars = MagicMock(), 3
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
mock_is_open.return_value = False
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
mock_is_open.assert_called_once_with()
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_is_open.reset_mock()
mock_is_open.return_value = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)])
@patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock):
mock_func = MagicMock()
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
args_iter, num_tasks = ([FOO], [BAR]), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)