diff --git a/src/asyncio_taskpool/internals/types.py b/src/asyncio_taskpool/internals/types.py index 5f0d62a..69c6816 100644 --- a/src/asyncio_taskpool/internals/types.py +++ b/src/asyncio_taskpool/internals/types.py @@ -23,7 +23,7 @@ This module should **not** be considered part of the public API. from asyncio.streams import StreamReader, StreamWriter from pathlib import Path -from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union +from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union T = TypeVar('T') @@ -32,7 +32,7 @@ ArgsT = Iterable[Any] KwArgsT = Mapping[str, Any] AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] -CoroutineFunc = Callable[[...], Awaitable[Any]] +CoroutineFunc = Callable[[...], Coroutine] EndCB = Callable CancelCB = Callable diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index ae1ea92..9355802 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -528,6 +528,8 @@ class BaseTaskPool: self._tasks_cancelled.clear() self._tasks_running.clear() self._closed = True + # TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will + # await it to allow blocking until a closing command comes from a server. class TaskPool(BaseTaskPool): @@ -566,36 +568,51 @@ class TaskPool(BaseTaskPool): return name i += 1 - async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, - num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: + async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, + num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: """ - Creates a coroutine with the supplied arguments and runs it as a new task in the pool. + Creates coroutines with the supplied arguments and runs them as new tasks in the pool. This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks. Args: group_name: - Name of the task group to add the new task to. + Name of the task group to add the new tasks to. func: - The coroutine function to be run as a task within the task pool. + The coroutine function to be run in `num` tasks within the task pool. args (optional): - The positional arguments to pass into the function call. + The positional arguments to pass into each function call. kwargs (optional): - The keyword-arguments to pass into the function call. + The keyword-arguments to pass into each function call. num (optional): The number of tasks to spawn with the specified parameters. end_callback (optional): - A callback to execute after the task has ended. + A callback to execute after each task has ended. It is run with the task's ID as its only positional argument. cancel_callback (optional): - A callback to execute after cancellation of the task. + A callback to execute after cancellation of each task. It is run with the task's ID as its only positional argument. """ if kwargs is None: kwargs = {} - # TODO: Add exception logging - await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback, - cancel_callback=cancel_callback) for _ in range(num))) + for i in range(num): + try: + coroutine = func(*args, **kwargs) + except Exception as e: + # This means there was probably something wrong with the function arguments. + log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)", + str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs)) + continue + try: + await self._start_task(coroutine, group_name=group_name, end_callback=end_callback, + cancel_callback=cancel_callback) + except CancelledError: + # Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any + # more tasks and can return immediately. + log.debug("Cancelled spawning tasks in group '%s' after %s out of %s tasks have been spawned", + group_name, i, num) + coroutine.close() + return def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: @@ -650,8 +667,8 @@ class TaskPool(BaseTaskPool): raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") self._task_groups.setdefault(group_name, TaskGroupRegister()) meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) - meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num, - end_callback=end_callback, cancel_callback=cancel_callback))) + meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num, + end_callback=end_callback, cancel_callback=cancel_callback))) return group_name @staticmethod @@ -696,6 +713,8 @@ class TaskPool(BaseTaskPool): # When the number of running tasks spawned by this method reaches the specified maximum, # this next line will block, until one of them ends and releases the semaphore. await map_semaphore.acquire() + # TODO: Clean up exception handling/logging. Cancellation can also occur while awaiting the semaphore. + # Wrap `star_function` call in a separate `try` block (similar to `_apply_spawner`). try: await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name, ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback) diff --git a/tests/test_pool.py b/tests/test_pool.py index 689c5a7..bbffdf8 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -455,49 +455,65 @@ class TaskPoolTestCase(CommonTestCase): self.assertEqual(expected_output, output) @patch.object(pool.TaskPool, '_start_task') - async def test__apply_num(self, mock__start_task: AsyncMock): - group_name = FOO + BAR - mock_awaitable = object() - mock_func = MagicMock(return_value=mock_awaitable) - args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3 + async def test__apply_spawner(self, mock__start_task: AsyncMock): + grp_name = FOO + BAR + mock_awaitable1, mock_awaitable2 = object(), object() + mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func') + args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3 end_cb, cancel_cb = MagicMock(), MagicMock() - self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)) - mock_func.assert_has_calls(3 * [call(*args, **kwargs)]) - mock__start_task.assert_has_awaits(3 * [ - call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) + self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb)) + mock_func.assert_has_calls(num * [call(*args, **kw)]) + mock__start_task.assert_has_awaits([ + call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb), + call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb), ]) - mock_func.reset_mock() + mock_func.reset_mock(side_effect=True) mock__start_task.reset_mock() - self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb)) - mock_func.assert_has_calls(num * [call(*args)]) - mock__start_task.assert_has_awaits(num * [ - call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) + # Simulate cancellation while the second task is being started. + mock__start_task.side_effect = [None, CancelledError, None] + mock_coroutine_to_close = MagicMock() + mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called'] + self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb)) + mock_func.assert_has_calls(2 * [call(*args)]) + mock__start_task.assert_has_awaits([ + call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb), + call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb), ]) + mock_coroutine_to_close.close.assert_called_once_with() @patch.object(pool, 'create_task') - @patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock()) + @patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock()) @patch.object(pool, 'TaskGroupRegister') @patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.BaseTaskPool, '_check_start') def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, - mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): + mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock): mock__generate_group_name.return_value = generated_name = 'name 123' mock_group_reg = set_up_mock_group_register(mock_reg_cls) - mock__apply_num.return_value = mock_apply_coroutine = object() + mock__apply_spawner.return_value = mock_apply_coroutine = object() mock_create_task.return_value = fake_task = object() mock_func, num, group_name = MagicMock(), 3, FOO + BAR args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} end_cb, cancel_cb = MagicMock(), MagicMock() + + self.task_pool._task_groups = {group_name: 'causes error'} + with self.assertRaises(exceptions.InvalidGroupName): + self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb) + mock__check_start.assert_called_once_with(function=mock_func) + mock__apply_spawner.assert_not_called() + mock_create_task.assert_not_called() + + mock__check_start.reset_mock() self.task_pool._task_groups = {} def check_assertions(_group_name, _output): self.assertEqual(_group_name, _output) mock__check_start.assert_called_once_with(function=mock_func) self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name]) - mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num, + end_callback=end_cb, cancel_callback=cancel_cb) mock_create_task.assert_called_once_with(mock_apply_coroutine) self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) @@ -507,7 +523,7 @@ class TaskPoolTestCase(CommonTestCase): mock__check_start.reset_mock() self.task_pool._task_groups.clear() - mock__apply_num.reset_mock() + mock__apply_spawner.reset_mock() mock_create_task.reset_mock() output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) @@ -695,6 +711,7 @@ class SimpleTaskPoolTestCase(CommonTestCase): def tearDown(self) -> None: self.base_class_init_patcher.stop() + super().tearDown() def test_init(self): self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)