diff --git a/docs/source/pages/pool.rst b/docs/source/pages/pool.rst index 5034f74..6b57883 100644 --- a/docs/source/pages/pool.rst +++ b/docs/source/pages/pool.rst @@ -81,7 +81,7 @@ By contrast, here is how you would do it with a task pool: ... pool = TaskPool() - group_name = await pool.apply(queue_worker_function, args=(q_in, q_out), num=5) + group_name = pool.apply(queue_worker_function, args=(q_in, q_out), num=5) ... pool.cancel_group(group_name) ... @@ -141,7 +141,7 @@ Or we could use a task pool: async def main(): ... pool = TaskPool() - await pool.map(another_worker_function, data_iterator, num_concurrent=5) + pool.map(another_worker_function, data_iterator, num_concurrent=5) ... await pool.gather_and_close() diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 5a998b2..cab3000 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -409,7 +409,7 @@ class BaseTaskPool: continue log.debug("%s cancelled tasks from group %s", str(self), group_name) - async def cancel_group(self, group_name: str, msg: str = None) -> None: + def cancel_group(self, group_name: str, msg: str = None) -> None: """ Cancels an entire group of tasks. @@ -431,11 +431,10 @@ class BaseTaskPool: group_reg = self._task_groups.pop(group_name) except KeyError: raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") - async with group_reg: - self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) + self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) log.debug("%s forgot task group %s", str(self), group_name) - async def cancel_all(self, msg: str = None) -> None: + def cancel_all(self, msg: str = None) -> None: """ Cancels all tasks still running within the pool (including meta tasks). @@ -449,8 +448,7 @@ class BaseTaskPool: log.warning("%s cancelling all tasks!", str(self)) while self._task_groups: group_name, group_reg = self._task_groups.popitem() - async with group_reg: - self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) + self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) def _pop_ended_meta_tasks(self) -> Set[Task]: """ @@ -598,8 +596,8 @@ class TaskPool(BaseTaskPool): await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback, cancel_callback=cancel_callback) for _ in range(num))) - async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, - group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: + def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ Creates tasks with the supplied arguments to be run in the pool. @@ -646,11 +644,10 @@ class TaskPool(BaseTaskPool): self._check_start(function=func) if group_name is None: group_name = self._generate_group_name('apply', func) - group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister()) - async with group_reg: - meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) - meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num, - end_callback=end_callback, cancel_callback=cancel_callback))) + self._task_groups.setdefault(group_name, TaskGroupRegister()) + meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) + meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num, + end_callback=end_callback, cancel_callback=cancel_callback))) return group_name @staticmethod @@ -711,8 +708,8 @@ class TaskPool(BaseTaskPool): str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg)) map_semaphore.release() - async def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, - end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: + def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: """ Creates tasks in the pool with arguments from the supplied iterable. @@ -760,14 +757,13 @@ class TaskPool(BaseTaskPool): raise ValueError("`num_concurrent` must be a positive integer.") if group_name in self._task_groups.keys(): raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") - self._task_groups[group_name] = group_reg = TaskGroupRegister() - async with group_reg: - meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) - meta_tasks.add(create_task(self._arg_consumer(group_name, num_concurrent, func, arg_iter, arg_stars, - end_callback=end_callback, cancel_callback=cancel_callback))) + self._task_groups[group_name] = TaskGroupRegister() + meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) + meta_tasks.add(create_task(self._arg_consumer(group_name, num_concurrent, func, arg_iter, arg_stars, + end_callback=end_callback, cancel_callback=cancel_callback))) - async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None, - end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: + def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ A task-based equivalent of the `multiprocessing.pool.Pool.map` method. @@ -819,12 +815,12 @@ class TaskPool(BaseTaskPool): """ if group_name is None: group_name = self._generate_group_name('map', func) - await self._map(group_name, num_concurrent, func, arg_iter, 0, - end_callback=end_callback, cancel_callback=cancel_callback) + self._map(group_name, num_concurrent, func, arg_iter, 0, + end_callback=end_callback, cancel_callback=cancel_callback) return group_name - async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, - group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: + def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None, + end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as positional arguments to the function. @@ -836,13 +832,12 @@ class TaskPool(BaseTaskPool): """ if group_name is None: group_name = self._generate_group_name('starmap', func) - await self._map(group_name, num_concurrent, func, args_iter, 1, - end_callback=end_callback, cancel_callback=cancel_callback) + self._map(group_name, num_concurrent, func, args_iter, 1, + end_callback=end_callback, cancel_callback=cancel_callback) return group_name - async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1, - group_name: str = None, end_callback: EndCB = None, - cancel_callback: CancelCB = None) -> str: + def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1, + group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: """ Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as keyword-arguments to the function. @@ -854,8 +849,8 @@ class TaskPool(BaseTaskPool): """ if group_name is None: group_name = self._generate_group_name('doublestarmap', func) - await self._map(group_name, num_concurrent, func, kwargs_iter, 2, - end_callback=end_callback, cancel_callback=cancel_callback) + self._map(group_name, num_concurrent, func, kwargs_iter, 2, + end_callback=end_callback, cancel_callback=cancel_callback) return group_name diff --git a/tests/test_pool.py b/tests/test_pool.py index ec3a4c6..f4bb175 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -351,29 +351,24 @@ class BaseTaskPoolTestCase(CommonTestCase): mock_cancel.assert_called_once_with(msg=FOO) @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') - async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): - mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock() - mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit) - self.task_pool._task_groups[FOO] = mock_group_reg + def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): + self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock() with self.assertRaises(exceptions.InvalidGroupName): - await self.task_pool.cancel_group(BAR) + self.task_pool.cancel_group(BAR) mock__cancel_and_remove_all_from_group.assert_not_called() - mock_grp_aenter.assert_not_called() - mock_grp_aexit.assert_not_called() - self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR)) + self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR)) + self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR) - mock_grp_aenter.assert_awaited_once_with() - mock_grp_aexit.assert_awaited_once() @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') - async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): - mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock() - mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit) - self.task_pool._task_groups[BAR] = mock_group_reg - self.assertIsNone(await self.task_pool.cancel_all(FOO)) - mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO) - mock_grp_aenter.assert_awaited_once_with() - mock_grp_aexit.assert_awaited_once() + def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): + mock_group_reg = MagicMock() + self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg} + self.assertIsNone(self.task_pool.cancel_all('msg')) + mock__cancel_and_remove_all_from_group.assert_has_calls([ + call(BAR, mock_group_reg, msg='msg'), + call(FOO, mock_group_reg, msg='msg') + ]) def test__pop_ended_meta_tasks(self): mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True) @@ -486,8 +481,8 @@ class TaskPoolTestCase(CommonTestCase): @patch.object(pool, 'TaskGroupRegister') @patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.BaseTaskPool, '_check_start') - async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, - mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): + def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, + mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): mock__generate_group_name.return_value = generated_name = 'name 123' mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock__apply_num.return_value = mock_apply_coroutine = object() @@ -501,25 +496,21 @@ class TaskPoolTestCase(CommonTestCase): self.assertEqual(_group_name, _output) mock__check_start.assert_called_once_with(function=mock_func) self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name]) - mock_group_reg.__aenter__.assert_awaited_once_with() mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_callback=end_cb, cancel_callback=cancel_cb) mock_create_task.assert_called_once_with(mock_apply_coroutine) - mock_group_reg.__aexit__.assert_awaited_once() self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) - output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb) + output = self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb) check_assertions(group_name, output) mock__generate_group_name.assert_not_called() mock__check_start.reset_mock() self.task_pool._task_groups.clear() - mock_group_reg.__aenter__.reset_mock() mock__apply_num.reset_mock() mock_create_task.reset_mock() - mock_group_reg.__aexit__.reset_mock() - output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) + output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) check_assertions(generated_name, output) mock__generate_group_name.assert_called_once_with('apply', mock_func) @@ -581,8 +572,8 @@ class TaskPoolTestCase(CommonTestCase): @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock) @patch.object(pool, 'TaskGroupRegister') @patch.object(pool.BaseTaskPool, '_check_start') - async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock, - mock_create_task: MagicMock): + def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock, + mock_create_task: MagicMock): mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock__arg_consumer.return_value = fake_consumer = object() mock_create_task.return_value = fake_task = object() @@ -592,7 +583,7 @@ class TaskPoolTestCase(CommonTestCase): end_cb, cancel_cb = MagicMock(), MagicMock() with self.assertRaises(ValueError): - await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) + self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) mock__check_start.assert_called_once_with(function=func) mock__check_start.reset_mock() @@ -601,82 +592,80 @@ class TaskPoolTestCase(CommonTestCase): self.task_pool._task_groups = {group_name: MagicMock()} with self.assertRaises(exceptions.InvalidGroupName): - await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) + self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) mock__check_start.assert_called_once_with(function=func) mock__check_start.reset_mock() self.task_pool._task_groups.clear() - self.assertIsNone(await self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)) + self.assertIsNone(self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)) mock__check_start.assert_called_once_with(function=func) mock_reg_cls.assert_called_once_with() self.task_pool._task_groups[group_name] = mock_group_reg - mock_group_reg.__aenter__.assert_awaited_once_with() mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars, end_callback=end_cb, cancel_callback=cancel_cb) mock_create_task.assert_called_once_with(fake_consumer) self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) - mock_group_reg.__aexit__.assert_awaited_once() @patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_generate_group_name') - async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): + def test_map(self, mock__generate_group_name: MagicMock, mock__map: MagicMock): mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock_func = MagicMock() arg_iter, num_concurrent, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR end_cb, cancel_cb = MagicMock(), MagicMock() - output = await self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb) + output = self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb) self.assertEqual(group_name, output) - mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, arg_iter, 0, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, arg_iter, 0, + end_callback=end_cb, cancel_callback=cancel_cb) mock__generate_group_name.assert_not_called() mock__map.reset_mock() - output = await self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb) + output = self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb) self.assertEqual(generated_name, output) - mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0, + end_callback=end_cb, cancel_callback=cancel_cb) mock__generate_group_name.assert_called_once_with('map', mock_func) @patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_generate_group_name') - async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): + def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock): mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock_func = MagicMock() args_iter, num_concurrent, group_name = ([FOO], [BAR]), 2, FOO + BAR end_cb, cancel_cb = MagicMock(), MagicMock() - output = await self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb) + output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb) self.assertEqual(group_name, output) - mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, args_iter, 1, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, args_iter, 1, + end_callback=end_cb, cancel_callback=cancel_cb) mock__generate_group_name.assert_not_called() mock__map.reset_mock() - output = await self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb) + output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb) self.assertEqual(generated_name, output) - mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, args_iter, 1, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, args_iter, 1, + end_callback=end_cb, cancel_callback=cancel_cb) mock__generate_group_name.assert_called_once_with('starmap', mock_func) @patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_generate_group_name') - async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): + async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock): mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock_func = MagicMock() kw_iter, num_concurrent, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR end_cb, cancel_cb = MagicMock(), MagicMock() - output = await self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb) + output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb) self.assertEqual(group_name, output) - mock__map.assert_awaited_once_with(group_name, num_concurrent, mock_func, kw_iter, 2, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, kw_iter, 2, + end_callback=end_cb, cancel_callback=cancel_cb) mock__generate_group_name.assert_not_called() mock__map.reset_mock() - output = await self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb) + output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb) self.assertEqual(generated_name, output) - mock__map.assert_awaited_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2, - end_callback=end_cb, cancel_callback=cancel_cb) + mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2, + end_callback=end_cb, cancel_callback=cancel_cb) mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func) diff --git a/usage/USAGE.md b/usage/USAGE.md index 7f227ee..3ce38e9 100644 --- a/usage/USAGE.md +++ b/usage/USAGE.md @@ -122,7 +122,7 @@ async def main() -> None: pool = TaskPool(3) # Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments). print("> Called `apply`") - await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) + pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) # Let the tasks work for a bit. await asyncio.sleep(1.5) # Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different @@ -134,7 +134,7 @@ async def main() -> None: # Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5) # only once there is room in the pool and no more than one other task of these new ones is running. args_list = [(0, 10), (10, 20), (20, 30), (30, 40)] - await pool.starmap(other_work, args_list, num_concurrent=2) + pool.starmap(other_work, args_list, num_concurrent=2) print("> Called `starmap`") # We block, until all tasks have ended. print("> Calling `gather_and_close`...")