diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 6f82cc9..ed8d9c4 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -625,7 +625,7 @@ class TaskPool(BaseTaskPool): # This means there was probably something wrong with the function arguments. log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)", str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs)) - continue + continue # TODO: Consider returning instead of continuing try: await self._start_task(coroutine, group_name=group_name, end_callback=end_callback, cancel_callback=cancel_callback) @@ -962,13 +962,24 @@ class SimpleTaskPool(BaseTaskPool): async def _start_num(self, num: int, group_name: str) -> None: """Starts `num` new tasks in group `group_name`.""" - start_coroutines = ( - self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name, - end_callback=self._end_callback, cancel_callback=self._cancel_callback) - for _ in range(num) - ) - # TODO: Same deal as with the other meta tasks, provide proper cancellation handling! - await gather(*start_coroutines) + for i in range(num): + try: + coroutine = self._func(*self._args, **self._kwargs) + except Exception as e: + # This means there was probably something wrong with the function arguments. + log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)", + str(e.__class__.__name__), str(self), self._func.__name__, + repr(self._args), repr(self._kwargs)) + continue # TODO: Consider returning instead of continuing + try: + await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback, + cancel_callback=self._cancel_callback) + except CancelledError: + # Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any + # more tasks and can return immediately. + log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num) + coroutine.close() + return def start(self, num: int) -> str: """ diff --git a/tests/test_pool.py b/tests/test_pool.py index 26fdbe5..dbcc1af 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -790,18 +790,30 @@ class SimpleTaskPoolTestCase(CommonTestCase): @patch.object(pool.SimpleTaskPool, '_start_task') async def test__start_num(self, mock__start_task: AsyncMock): - fake_coroutine = object() - self.task_pool._func = MagicMock(return_value=fake_coroutine) - num = 3 group_name = FOO + BAR + 'abc' + mock_awaitable1, mock_awaitable2 = object(), object() + self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func') + num = 3 self.assertIsNone(await self.task_pool._start_num(num, group_name)) - self.task_pool._func.assert_has_calls(num * [ - call(*self.task_pool._args, **self.task_pool._kwargs) - ]) - mock__start_task.assert_has_awaits(num * [ - call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback, - cancel_callback=self.task_pool._cancel_callback) - ]) + self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)]) + call_kw = { + 'group_name': group_name, + 'end_callback': self.task_pool._end_callback, + 'cancel_callback': self.task_pool._cancel_callback + } + mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)]) + + self.task_pool._func.reset_mock(side_effect=True) + mock__start_task.reset_mock() + + # Simulate cancellation while the second task is being started. + mock__start_task.side_effect = [None, CancelledError, None] + mock_coroutine_to_close = MagicMock() + self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called'] + self.assertIsNone(await self.task_pool._start_num(num, group_name)) + self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)]) + mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)]) + mock_coroutine_to_close.close.assert_called_once_with() @patch.object(pool, 'create_task') @patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())