improved exception handling/logging in _map

This commit is contained in:
Daniil Fajnberg 2022-03-31 20:58:52 +02:00
parent 0daed04167
commit 56d38a3b44
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
2 changed files with 77 additions and 40 deletions

View File

@ -326,6 +326,8 @@ class BaseTaskPool:
"""
self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
await self._enough_room.acquire()
# TODO: Make sure that cancellation (group or pool) interrupts this method after context switching!
# Possibly make use of the task group register for that.
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
async with group_reg:
task_id = self._num_started
@ -609,8 +611,7 @@ class TaskPool(BaseTaskPool):
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled spawning tasks in group '%s' after %s out of %s tasks have been spawned",
group_name, i, num)
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
@ -707,29 +708,31 @@ class TaskPool(BaseTaskPool):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the task's ID as its only positional argument.
"""
map_semaphore = Semaphore(num_concurrent)
release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
for next_arg in arg_iter:
semaphore = Semaphore(num_concurrent)
release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback)
for i, next_arg in enumerate(arg_iter):
semaphore_acquired = False
try:
coroutine = star_function(func, next_arg, arg_stars=arg_stars)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)",
str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg))
continue
try:
# When the number of running tasks spawned by this method reaches the specified maximum,
# this next line will block, until one of them ends and releases the semaphore.
await map_semaphore.acquire()
# TODO: Clean up exception handling/logging. Cancellation can also occur while awaiting the semaphore.
# Wrap `star_function` call in a separate `try` block (similar to `_apply_spawner`).
try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
semaphore_acquired = await semaphore.acquire()
await self._start_task(coroutine, group_name=group_name, ignore_lock=True,
end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
# This means that no more tasks are supposed to be created from this `arg_iter`;
# thus, we can forget about the rest of the arguments.
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
map_semaphore.release()
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.)
log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i)
coroutine.close()
if semaphore_acquired:
semaphore.release()
return
except Exception as e:
# This means an exception occurred during task **creation**, meaning no task has been created.
# It does not imply an error within the task itself.
log.exception("%s occurred while trying to create task: %s(%s%s)",
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
map_semaphore.release()
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
@ -943,6 +946,7 @@ class SimpleTaskPool(BaseTaskPool):
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num)
)
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
await gather(*start_coroutines)
def start(self, num: int) -> str:

View File

@ -548,20 +548,20 @@ class TaskPoolTestCase(CommonTestCase):
n = 2
mock_semaphore_cls.return_value = semaphore = Semaphore(n)
mock__get_map_end_callback.return_value = map_cb = MagicMock()
awaitable = 'totally an awaitable'
mock_star_function.side_effect = [awaitable, Exception(), awaitable]
awaitable1, awaitable2 = 'totally an awaitable', object()
mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
arg1, arg2, bad = 123456789, 'function argument', None
args = [arg1, bad, arg2]
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then
# acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked.
self.assertTrue(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
mock__start_task.assert_has_awaits([
call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
])
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
@ -572,17 +572,50 @@ class TaskPoolTestCase(CommonTestCase):
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting a task:
mock_semaphore_cls.return_value = semaphore = Semaphore(1)
mock_star_function.side_effect = CancelledError()
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
self.assertFalse(semaphore.locked())
# With a CancelledError thrown while acquiring the semaphore:
mock_acquire = AsyncMock(side_effect=[True, CancelledError])
mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
mock_star_function.return_value = mock_coroutine = MagicMock()
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_not_called()
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
end_callback=map_cb, cancel_callback=cancel_cb)
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_not_called()
self.assertEqual(FOO, next(arg_it))
mock_acquire.reset_mock(side_effect=True)
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting the task:
mock__start_task.side_effect = [None, CancelledError]
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_has_awaits(2 * [
call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_called_once_with()
self.assertEqual(FOO, next(arg_it))
@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)