diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 9355802..320144c 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -326,6 +326,8 @@ class BaseTaskPool: """ self._check_start(awaitable=awaitable, ignore_lock=ignore_lock) await self._enough_room.acquire() + # TODO: Make sure that cancellation (group or pool) interrupts this method after context switching! + # Possibly make use of the task group register for that. group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister()) async with group_reg: task_id = self._num_started @@ -609,8 +611,7 @@ class TaskPool(BaseTaskPool): except CancelledError: # Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any # more tasks and can return immediately. - log.debug("Cancelled spawning tasks in group '%s' after %s out of %s tasks have been spawned", - group_name, i, num) + log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num) coroutine.close() return @@ -707,29 +708,31 @@ class TaskPool(BaseTaskPool): The callback that was specified to execute after cancellation of the task (and the next one). It is run with the task's ID as its only positional argument. """ - map_semaphore = Semaphore(num_concurrent) - release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback) - for next_arg in arg_iter: - # When the number of running tasks spawned by this method reaches the specified maximum, - # this next line will block, until one of them ends and releases the semaphore. - await map_semaphore.acquire() - # TODO: Clean up exception handling/logging. Cancellation can also occur while awaiting the semaphore. - # Wrap `star_function` call in a separate `try` block (similar to `_apply_spawner`). + semaphore = Semaphore(num_concurrent) + release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback) + for i, next_arg in enumerate(arg_iter): + semaphore_acquired = False try: - await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, - ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) - except CancelledError: - # This means that no more tasks are supposed to be created from this `arg_iter`; - # thus, we can forget about the rest of the arguments. - log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name) - map_semaphore.release() - return + coroutine = star_function(func, next_arg, arg_stars=arg_stars) except Exception as e: - # This means an exception occurred during task **creation**, meaning no task has been created. - # It does not imply an error within the task itself. - log.exception("%s occurred while trying to create task: %s(%s%s)", - str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) - map_semaphore.release() + # This means there was probably something wrong with the function arguments. + log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)", + str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg)) + continue + try: + # When the number of running tasks spawned by this method reaches the specified maximum, + # this next line will block, until one of them ends and releases the semaphore. + semaphore_acquired = await semaphore.acquire() + await self._start_task(coroutine, group_name=group_name, ignore_lock=True, + end_callback=release_cb, cancel_callback=cancel_callback) + except CancelledError: + # Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any + # more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.) + log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i) + coroutine.close() + if semaphore_acquired: + semaphore.release() + return def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: @@ -943,6 +946,7 @@ class SimpleTaskPool(BaseTaskPool): end_callback=self._end_callback, cancel_callback=self._cancel_callback) for _ in range(num) ) + # TODO: Same deal as with the other meta tasks, provide proper cancellation handling! await gather(*start_coroutines) def start(self, num: int) -> str: diff --git a/tests/test_pool.py b/tests/test_pool.py index bbffdf8..912e16f 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -548,20 +548,20 @@ class TaskPoolTestCase(CommonTestCase): n = 2 mock_semaphore_cls.return_value = semaphore = Semaphore(n) mock__get_map_end_callback.return_value = map_cb = MagicMock() - awaitable = 'totally an awaitable' - mock_star_function.side_effect = [awaitable, Exception(), awaitable] + awaitable1, awaitable2 = 'totally an awaitable', object() + mock_star_function.side_effect = [awaitable1, Exception(), awaitable2] arg1, arg2, bad = 123456789, 'function argument', None args = [arg1, bad, arg2] - group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 + grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) - # We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then - # acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked. + self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb)) + # We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked. self.assertTrue(semaphore.locked()) mock_semaphore_cls.assert_called_once_with(n) mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) - mock__start_task.assert_has_awaits(2 * [ - call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) + mock__start_task.assert_has_awaits([ + call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb), + call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb), ]) mock_star_function.assert_has_calls([ call(mock_func, arg1, arg_stars=stars), @@ -572,17 +572,50 @@ class TaskPoolTestCase(CommonTestCase): mock_semaphore_cls.reset_mock() mock__get_map_end_callback.reset_mock() mock__start_task.reset_mock() - mock_star_function.reset_mock() + mock_star_function.reset_mock(side_effect=True) - # With a CancelledError thrown while starting a task: - mock_semaphore_cls.return_value = semaphore = Semaphore(1) - mock_star_function.side_effect = CancelledError() - self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) - self.assertFalse(semaphore.locked()) + # With a CancelledError thrown while acquiring the semaphore: + mock_acquire = AsyncMock(side_effect=[True, CancelledError]) + mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire) + mock_star_function.return_value = mock_coroutine = MagicMock() + arg_it = iter(arg for arg in (arg1, arg2, FOO)) + self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb)) mock_semaphore_cls.assert_called_once_with(n) - mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) - mock__start_task.assert_not_called() - mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) + mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb) + mock_star_function.assert_has_calls([ + call(mock_func, arg1, arg_stars=stars), + call(mock_func, arg2, arg_stars=stars) + ]) + mock_acquire.assert_has_awaits([call(), call()]) + mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True, + end_callback=map_cb, cancel_callback=cancel_cb) + mock_coroutine.close.assert_called_once_with() + mock_semaphore.release.assert_not_called() + self.assertEqual(FOO, next(arg_it)) + + mock_acquire.reset_mock(side_effect=True) + mock_semaphore_cls.reset_mock() + mock__get_map_end_callback.reset_mock() + mock__start_task.reset_mock() + mock_star_function.reset_mock(side_effect=True) + + # With a CancelledError thrown while starting the task: + mock__start_task.side_effect = [None, CancelledError] + arg_it = iter(arg for arg in (arg1, arg2, FOO)) + self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb)) + mock_semaphore_cls.assert_called_once_with(n) + mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb) + mock_star_function.assert_has_calls([ + call(mock_func, arg1, arg_stars=stars), + call(mock_func, arg2, arg_stars=stars) + ]) + mock_acquire.assert_has_awaits([call(), call()]) + mock__start_task.assert_has_awaits(2 * [ + call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) + ]) + mock_coroutine.close.assert_called_once_with() + mock_semaphore.release.assert_called_once_with() + self.assertEqual(FOO, next(arg_it)) @patch.object(pool, 'create_task') @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)