renamed method get_task_group_ids and extended to accept any number of group names

This commit is contained in:
Daniil Fajnberg 2022-03-07 14:21:24 +01:00
parent 287906a218
commit 7c66604ad0
3 changed files with 16 additions and 13 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.5.0 version = 0.5.1
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -178,23 +178,26 @@ class BaseTaskPool:
""" """
return self._enough_room.locked() return self._enough_room.locked()
def get_task_group_ids(self, group_name: str) -> Set[int]: def get_group_ids(self, *group_names: str) -> Set[int]:
""" """
Returns the set of IDs of all tasks in the specified group. Returns the set of IDs of all tasks in the specified groups.
Args: Args:
group_name: Must be a name of a task group that exists within the pool. *group_names: Each element must be a name of a task group that exists within the pool.
Returns: Returns:
Set of integers representing the task IDs belonging to the specified group. Set of integers representing the task IDs belonging to the specified groups.
Raises: Raises:
`InvalidGroupName` if no task group named `group_name` exists in the pool. `InvalidGroupName` if one of the specified `group_names` does not exist in the pool.
""" """
try: ids = set()
return set(self._task_groups[group_name]) for name in group_names:
except KeyError: try:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") ids.update(self._task_groups[name])
except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {name} exists in this pool.")
return ids
def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None, def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None,
ignore_lock: bool = False) -> None: ignore_lock: bool = False) -> None:

View File

@ -163,12 +163,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
def test_is_full(self): def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
def test_get_task_group_ids(self): def test_get_group_ids(self):
group_name, ids = 'abcdef', [1, 2, 3] group_name, ids = 'abcdef', [1, 2, 3]
self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids)) self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name)) self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name))
with self.assertRaises(exceptions.InvalidGroupName): with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.get_task_group_ids('something else') self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self): async def test__check_start(self):
self.task_pool._closed = True self.task_pool._closed = True