improved exception handling/logging in _map

This commit is contained in:
Daniil Fajnberg 2022-03-31 20:58:52 +02:00
parent 0daed04167
commit 56d38a3b44
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
2 changed files with 77 additions and 40 deletions

View File

@ -326,6 +326,8 @@ class BaseTaskPool:
""" """
self._check_start(awaitable=awaitable, ignore_lock=ignore_lock) self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
await self._enough_room.acquire() await self._enough_room.acquire()
# TODO: Make sure that cancellation (group or pool) interrupts this method after context switching!
# Possibly make use of the task group register for that.
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister()) group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
async with group_reg: async with group_reg:
task_id = self._num_started task_id = self._num_started
@ -609,8 +611,7 @@ class TaskPool(BaseTaskPool):
except CancelledError: except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any # Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately. # more tasks and can return immediately.
log.debug("Cancelled spawning tasks in group '%s' after %s out of %s tasks have been spawned", log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
group_name, i, num)
coroutine.close() coroutine.close()
return return
@ -707,29 +708,31 @@ class TaskPool(BaseTaskPool):
The callback that was specified to execute after cancellation of the task (and the next one). The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
""" """
map_semaphore = Semaphore(num_concurrent) semaphore = Semaphore(num_concurrent)
release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback) release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback)
for next_arg in arg_iter: for i, next_arg in enumerate(arg_iter):
# When the number of running tasks spawned by this method reaches the specified maximum, semaphore_acquired = False
# this next line will block, until one of them ends and releases the semaphore.
await map_semaphore.acquire()
# TODO: Clean up exception handling/logging. Cancellation can also occur while awaiting the semaphore.
# Wrap `star_function` call in a separate `try` block (similar to `_apply_spawner`).
try: try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, coroutine = star_function(func, next_arg, arg_stars=arg_stars)
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
# This means that no more tasks are supposed to be created from this `arg_iter`;
# thus, we can forget about the rest of the arguments.
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
map_semaphore.release()
return
except Exception as e: except Exception as e:
# This means an exception occurred during task **creation**, meaning no task has been created. # This means there was probably something wrong with the function arguments.
# It does not imply an error within the task itself. log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)",
log.exception("%s occurred while trying to create task: %s(%s%s)", str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg))
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) continue
map_semaphore.release() try:
# When the number of running tasks spawned by this method reaches the specified maximum,
# this next line will block, until one of them ends and releases the semaphore.
semaphore_acquired = await semaphore.acquire()
await self._start_task(coroutine, group_name=group_name, ignore_lock=True,
end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.)
log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i)
coroutine.close()
if semaphore_acquired:
semaphore.release()
return
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
@ -943,6 +946,7 @@ class SimpleTaskPool(BaseTaskPool):
end_callback=self._end_callback, cancel_callback=self._cancel_callback) end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num) for _ in range(num)
) )
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
await gather(*start_coroutines) await gather(*start_coroutines)
def start(self, num: int) -> str: def start(self, num: int) -> str:

View File

@ -548,20 +548,20 @@ class TaskPoolTestCase(CommonTestCase):
n = 2 n = 2
mock_semaphore_cls.return_value = semaphore = Semaphore(n) mock_semaphore_cls.return_value = semaphore = Semaphore(n)
mock__get_map_end_callback.return_value = map_cb = MagicMock() mock__get_map_end_callback.return_value = map_cb = MagicMock()
awaitable = 'totally an awaitable' awaitable1, awaitable2 = 'totally an awaitable', object()
mock_star_function.side_effect = [awaitable, Exception(), awaitable] mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
arg1, arg2, bad = 123456789, 'function argument', None arg1, arg2, bad = 123456789, 'function argument', None
args = [arg1, bad, arg2] args = [arg1, bad, arg2]
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then # We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked.
# acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
self.assertTrue(semaphore.locked()) self.assertTrue(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(n) mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [ mock__start_task.assert_has_awaits([
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
]) ])
mock_star_function.assert_has_calls([ mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars), call(mock_func, arg1, arg_stars=stars),
@ -572,17 +572,50 @@ class TaskPoolTestCase(CommonTestCase):
mock_semaphore_cls.reset_mock() mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock() mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock() mock__start_task.reset_mock()
mock_star_function.reset_mock() mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting a task: # With a CancelledError thrown while acquiring the semaphore:
mock_semaphore_cls.return_value = semaphore = Semaphore(1) mock_acquire = AsyncMock(side_effect=[True, CancelledError])
mock_star_function.side_effect = CancelledError() mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) mock_star_function.return_value = mock_coroutine = MagicMock()
self.assertFalse(semaphore.locked()) arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n) mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock__start_task.assert_not_called() mock_star_function.assert_has_calls([
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
end_callback=map_cb, cancel_callback=cancel_cb)
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_not_called()
self.assertEqual(FOO, next(arg_it))
mock_acquire.reset_mock(side_effect=True)
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting the task:
mock__start_task.side_effect = [None, CancelledError]
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_has_awaits(2 * [
call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_called_once_with()
self.assertEqual(FOO, next(arg_it))
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)