From 7c66604ad0b95eff7a8cf81d43095e5c483114ca Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Mon, 7 Mar 2022 14:21:24 +0100 Subject: [PATCH] renamed method `get_task_group_ids` and extended to accept any number of group names --- setup.cfg | 2 +- src/asyncio_taskpool/pool.py | 21 ++++++++++++--------- tests/test_pool.py | 6 +++--- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/setup.cfg b/setup.cfg index 60a81b4..8474f40 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.5.0 +version = 0.5.1 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 9f0abfc..183cbc7 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -178,23 +178,26 @@ class BaseTaskPool: """ return self._enough_room.locked() - def get_task_group_ids(self, group_name: str) -> Set[int]: + def get_group_ids(self, *group_names: str) -> Set[int]: """ - Returns the set of IDs of all tasks in the specified group. + Returns the set of IDs of all tasks in the specified groups. Args: - group_name: Must be a name of a task group that exists within the pool. + *group_names: Each element must be a name of a task group that exists within the pool. Returns: - Set of integers representing the task IDs belonging to the specified group. + Set of integers representing the task IDs belonging to the specified groups. Raises: - `InvalidGroupName` if no task group named `group_name` exists in the pool. + `InvalidGroupName` if one of the specified `group_names` does not exist in the pool. """ - try: - return set(self._task_groups[group_name]) - except KeyError: - raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") + ids = set() + for name in group_names: + try: + ids.update(self._task_groups[name]) + except KeyError: + raise exceptions.InvalidGroupName(f"No task group named {name} exists in this pool.") + return ids def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None, ignore_lock: bool = False) -> None: diff --git a/tests/test_pool.py b/tests/test_pool.py index aef5fc0..e467d57 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -163,12 +163,12 @@ class BaseTaskPoolTestCase(CommonTestCase): def test_is_full(self): self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) - def test_get_task_group_ids(self): + def test_get_group_ids(self): group_name, ids = 'abcdef', [1, 2, 3] self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids)) - self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name)) + self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name)) with self.assertRaises(exceptions.InvalidGroupName): - self.task_pool.get_task_group_ids('something else') + self.task_pool.get_group_ids(group_name, 'something else') async def test__check_start(self): self.task_pool._closed = True