Compare commits

...

2 Commits

3 changed files with 145 additions and 72 deletions

View File

@ -23,7 +23,7 @@ This module should **not** be considered part of the public API.
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T') T = TypeVar('T')
@ -32,7 +32,7 @@ ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Coroutine]
EndCB = Callable EndCB = Callable
CancelCB = Callable CancelCB = Callable

View File

@ -326,6 +326,8 @@ class BaseTaskPool:
""" """
self._check_start(awaitable=awaitable, ignore_lock=ignore_lock) self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
await self._enough_room.acquire() await self._enough_room.acquire()
# TODO: Make sure that cancellation (group or pool) interrupts this method after context switching!
# Possibly make use of the task group register for that.
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister()) group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
async with group_reg: async with group_reg:
task_id = self._num_started task_id = self._num_started
@ -528,6 +530,8 @@ class BaseTaskPool:
self._tasks_cancelled.clear() self._tasks_cancelled.clear()
self._tasks_running.clear() self._tasks_running.clear()
self._closed = True self._closed = True
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server.
class TaskPool(BaseTaskPool): class TaskPool(BaseTaskPool):
@ -566,36 +570,50 @@ class TaskPool(BaseTaskPool):
return name return name
i += 1 i += 1
async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
""" """
Creates a coroutine with the supplied arguments and runs it as a new task in the pool. Creates coroutines with the supplied arguments and runs them as new tasks in the pool.
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks. This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
Args: Args:
group_name: group_name:
Name of the task group to add the new task to. Name of the task group to add the new tasks to.
func: func:
The coroutine function to be run as a task within the task pool. The coroutine function to be run in `num` tasks within the task pool.
args (optional): args (optional):
The positional arguments to pass into the function call. The positional arguments to pass into each function call.
kwargs (optional): kwargs (optional):
The keyword-arguments to pass into the function call. The keyword-arguments to pass into each function call.
num (optional): num (optional):
The number of tasks to spawn with the specified parameters. The number of tasks to spawn with the specified parameters.
end_callback (optional): end_callback (optional):
A callback to execute after the task has ended. A callback to execute after each task has ended.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
cancel_callback (optional): cancel_callback (optional):
A callback to execute after cancellation of the task. A callback to execute after cancellation of each task.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
""" """
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
# TODO: Add exception logging for i in range(num):
await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback, try:
cancel_callback=cancel_callback) for _ in range(num))) coroutine = func(*args, **kwargs)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue
try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None, def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
@ -650,8 +668,8 @@ class TaskPool(BaseTaskPool):
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
self._task_groups.setdefault(group_name, TaskGroupRegister()) self._task_groups.setdefault(group_name, TaskGroupRegister())
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num, meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num,
end_callback=end_callback, cancel_callback=cancel_callback))) end_callback=end_callback, cancel_callback=cancel_callback)))
return group_name return group_name
@staticmethod @staticmethod
@ -690,27 +708,31 @@ class TaskPool(BaseTaskPool):
The callback that was specified to execute after cancellation of the task (and the next one). The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
""" """
map_semaphore = Semaphore(num_concurrent) semaphore = Semaphore(num_concurrent)
release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback) release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback)
for next_arg in arg_iter: for i, next_arg in enumerate(arg_iter):
# When the number of running tasks spawned by this method reaches the specified maximum, semaphore_acquired = False
# this next line will block, until one of them ends and releases the semaphore.
await map_semaphore.acquire()
try: try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, coroutine = star_function(func, next_arg, arg_stars=arg_stars)
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
# This means that no more tasks are supposed to be created from this `arg_iter`;
# thus, we can forget about the rest of the arguments.
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
map_semaphore.release()
return
except Exception as e: except Exception as e:
# This means an exception occurred during task **creation**, meaning no task has been created. # This means there was probably something wrong with the function arguments.
# It does not imply an error within the task itself. log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)",
log.exception("%s occurred while trying to create task: %s(%s%s)", str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg))
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) continue
map_semaphore.release() try:
# When the number of running tasks spawned by this method reaches the specified maximum,
# this next line will block, until one of them ends and releases the semaphore.
semaphore_acquired = await semaphore.acquire()
await self._start_task(coroutine, group_name=group_name, ignore_lock=True,
end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.)
log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i)
coroutine.close()
if semaphore_acquired:
semaphore.release()
return
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
@ -924,6 +946,7 @@ class SimpleTaskPool(BaseTaskPool):
end_callback=self._end_callback, cancel_callback=self._cancel_callback) end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num) for _ in range(num)
) )
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
await gather(*start_coroutines) await gather(*start_coroutines)
def start(self, num: int) -> str: def start(self, num: int) -> str:

View File

@ -455,49 +455,65 @@ class TaskPoolTestCase(CommonTestCase):
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
@patch.object(pool.TaskPool, '_start_task') @patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock): async def test__apply_spawner(self, mock__start_task: AsyncMock):
group_name = FOO + BAR grp_name = FOO + BAR
mock_awaitable = object() mock_awaitable1, mock_awaitable2 = object(), object()
mock_func = MagicMock(return_value=mock_awaitable) mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3 args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)]) mock_func.assert_has_calls(num * [call(*args, **kw)])
mock__start_task.assert_has_awaits(3 * [ mock__start_task.assert_has_awaits([
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
]) ])
mock_func.reset_mock() mock_func.reset_mock(side_effect=True)
mock__start_task.reset_mock() mock__start_task.reset_mock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb)) # Simulate cancellation while the second task is being started.
mock_func.assert_has_calls(num * [call(*args)]) mock__start_task.side_effect = [None, CancelledError, None]
mock__start_task.assert_has_awaits(num * [ mock_coroutine_to_close = MagicMock()
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(2 * [call(*args)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
]) ])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock()) @patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister') @patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start') @patch.object(pool.BaseTaskPool, '_check_start')
def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123' mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object() mock__apply_spawner.return_value = mock_apply_coroutine = object()
mock_create_task.return_value = fake_task = object() mock_create_task.return_value = fake_task = object()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._task_groups = {group_name: 'causes error'}
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=mock_func)
mock__apply_spawner.assert_not_called()
mock_create_task.assert_not_called()
mock__check_start.reset_mock()
self.task_pool._task_groups = {} self.task_pool._task_groups = {}
def check_assertions(_group_name, _output): def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output) self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func) mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name]) self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine) mock_create_task.assert_called_once_with(mock_apply_coroutine)
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
@ -507,7 +523,7 @@ class TaskPoolTestCase(CommonTestCase):
mock__check_start.reset_mock() mock__check_start.reset_mock()
self.task_pool._task_groups.clear() self.task_pool._task_groups.clear()
mock__apply_num.reset_mock() mock__apply_spawner.reset_mock()
mock_create_task.reset_mock() mock_create_task.reset_mock()
output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
@ -532,20 +548,20 @@ class TaskPoolTestCase(CommonTestCase):
n = 2 n = 2
mock_semaphore_cls.return_value = semaphore = Semaphore(n) mock_semaphore_cls.return_value = semaphore = Semaphore(n)
mock__get_map_end_callback.return_value = map_cb = MagicMock() mock__get_map_end_callback.return_value = map_cb = MagicMock()
awaitable = 'totally an awaitable' awaitable1, awaitable2 = 'totally an awaitable', object()
mock_star_function.side_effect = [awaitable, Exception(), awaitable] mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
arg1, arg2, bad = 123456789, 'function argument', None arg1, arg2, bad = 123456789, 'function argument', None
args = [arg1, bad, arg2] args = [arg1, bad, arg2]
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then # We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked.
# acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
self.assertTrue(semaphore.locked()) self.assertTrue(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(n) mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [ mock__start_task.assert_has_awaits([
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
]) ])
mock_star_function.assert_has_calls([ mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars), call(mock_func, arg1, arg_stars=stars),
@ -556,17 +572,50 @@ class TaskPoolTestCase(CommonTestCase):
mock_semaphore_cls.reset_mock() mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock() mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock() mock__start_task.reset_mock()
mock_star_function.reset_mock() mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting a task: # With a CancelledError thrown while acquiring the semaphore:
mock_semaphore_cls.return_value = semaphore = Semaphore(1) mock_acquire = AsyncMock(side_effect=[True, CancelledError])
mock_star_function.side_effect = CancelledError() mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) mock_star_function.return_value = mock_coroutine = MagicMock()
self.assertFalse(semaphore.locked()) arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n) mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock__start_task.assert_not_called() mock_star_function.assert_has_calls([
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
end_callback=map_cb, cancel_callback=cancel_cb)
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_not_called()
self.assertEqual(FOO, next(arg_it))
mock_acquire.reset_mock(side_effect=True)
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting the task:
mock__start_task.side_effect = [None, CancelledError]
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_has_awaits(2 * [
call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_called_once_with()
self.assertEqual(FOO, next(arg_it))
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
@ -695,6 +744,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def tearDown(self) -> None: def tearDown(self) -> None:
self.base_class_init_patcher.stop() self.base_class_init_patcher.stop()
super().tearDown()
def test_init(self): def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func) self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)