improved exception handling/logging in `apply` meta task

This commit is contained in:
Daniil Fajnberg 2022-03-31 11:04:31 +02:00
parent 80fc91ec47
commit 0daed04167
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
3 changed files with 72 additions and 36 deletions

View File

@ -23,7 +23,7 @@ This module should **not** be considered part of the public API.
from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T')
@ -32,7 +32,7 @@ ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
CoroutineFunc = Callable[[...], Coroutine]
EndCB = Callable
CancelCB = Callable

View File

@ -528,6 +528,8 @@ class BaseTaskPool:
self._tasks_cancelled.clear()
self._tasks_running.clear()
self._closed = True
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server.
class TaskPool(BaseTaskPool):
@ -566,36 +568,51 @@ class TaskPool(BaseTaskPool):
return name
i += 1
async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
Creates coroutines with the supplied arguments and runs them as new tasks in the pool.
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
Args:
group_name:
Name of the task group to add the new task to.
Name of the task group to add the new tasks to.
func:
The coroutine function to be run as a task within the task pool.
The coroutine function to be run in `num` tasks within the task pool.
args (optional):
The positional arguments to pass into the function call.
The positional arguments to pass into each function call.
kwargs (optional):
The keyword-arguments to pass into the function call.
The keyword-arguments to pass into each function call.
num (optional):
The number of tasks to spawn with the specified parameters.
end_callback (optional):
A callback to execute after the task has ended.
A callback to execute after each task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
A callback to execute after cancellation of each task.
It is run with the task's ID as its only positional argument.
"""
if kwargs is None:
kwargs = {}
# TODO: Add exception logging
await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback) for _ in range(num)))
for i in range(num):
try:
coroutine = func(*args, **kwargs)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue
try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled spawning tasks in group '%s' after %s out of %s tasks have been spawned",
group_name, i, num)
coroutine.close()
return
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
@ -650,8 +667,8 @@ class TaskPool(BaseTaskPool):
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
self._task_groups.setdefault(group_name, TaskGroupRegister())
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
end_callback=end_callback, cancel_callback=cancel_callback)))
meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num,
end_callback=end_callback, cancel_callback=cancel_callback)))
return group_name
@staticmethod
@ -696,6 +713,8 @@ class TaskPool(BaseTaskPool):
# When the number of running tasks spawned by this method reaches the specified maximum,
# this next line will block, until one of them ends and releases the semaphore.
await map_semaphore.acquire()
# TODO: Clean up exception handling/logging. Cancellation can also occur while awaiting the semaphore.
# Wrap `star_function` call in a separate `try` block (similar to `_apply_spawner`).
try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)

View File

@ -455,49 +455,65 @@ class TaskPoolTestCase(CommonTestCase):
self.assertEqual(expected_output, output)
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock):
group_name = FOO + BAR
mock_awaitable = object()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
async def test__apply_spawner(self, mock__start_task: AsyncMock):
grp_name = FOO + BAR
mock_awaitable1, mock_awaitable2 = object(), object()
mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
mock__start_task.assert_has_awaits(3 * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args, **kw)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
])
mock_func.reset_mock()
mock_func.reset_mock(side_effect=True)
mock__start_task.reset_mock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args)])
mock__start_task.assert_has_awaits(num * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(2 * [call(*args)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
@patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start')
def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object()
mock__apply_spawner.return_value = mock_apply_coroutine = object()
mock_create_task.return_value = fake_task = object()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._task_groups = {group_name: 'causes error'}
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=mock_func)
mock__apply_spawner.assert_not_called()
mock_create_task.assert_not_called()
mock__check_start.reset_mock()
self.task_pool._task_groups = {}
def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine)
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
@ -507,7 +523,7 @@ class TaskPoolTestCase(CommonTestCase):
mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
mock__apply_num.reset_mock()
mock__apply_spawner.reset_mock()
mock_create_task.reset_mock()
output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
@ -695,6 +711,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def tearDown(self) -> None:
self.base_class_init_patcher.stop()
super().tearDown()
def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)