Improve exception handling in `_start_num`

This commit is contained in:
Daniil Fajnberg 2022-04-18 13:45:15 +02:00
parent d047b99119
commit 36527ccffc
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
2 changed files with 41 additions and 18 deletions

View File

@ -625,7 +625,7 @@ class TaskPool(BaseTaskPool):
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue
continue # TODO: Consider returning instead of continuing
try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback)
@ -962,13 +962,24 @@ class SimpleTaskPool(BaseTaskPool):
async def _start_num(self, num: int, group_name: str) -> None:
"""Starts `num` new tasks in group `group_name`."""
start_coroutines = (
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num)
)
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
await gather(*start_coroutines)
for i in range(num):
try:
coroutine = self._func(*self._args, **self._kwargs)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), str(self), self._func.__name__,
repr(self._args), repr(self._kwargs))
continue # TODO: Consider returning instead of continuing
try:
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
cancel_callback=self._cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
def start(self, num: int) -> str:
"""

View File

@ -790,18 +790,30 @@ class SimpleTaskPoolTestCase(CommonTestCase):
@patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_num(self, mock__start_task: AsyncMock):
fake_coroutine = object()
self.task_pool._func = MagicMock(return_value=fake_coroutine)
num = 3
group_name = FOO + BAR + 'abc'
mock_awaitable1, mock_awaitable2 = object(), object()
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
num = 3
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(num * [
call(*self.task_pool._args, **self.task_pool._kwargs)
])
mock__start_task.assert_has_awaits(num * [
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback,
cancel_callback=self.task_pool._cancel_callback)
])
self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
call_kw = {
'group_name': group_name,
'end_callback': self.task_pool._end_callback,
'cancel_callback': self.task_pool._cancel_callback
}
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
self.task_pool._func.reset_mock(side_effect=True)
mock__start_task.reset_mock()
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task')
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())