Compare commits

...

3 Commits

7 changed files with 177 additions and 18 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.5.0 version = 0.5.1
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -37,10 +37,21 @@ class CLIENT_INFO:
class CMD: class CMD:
__slots__ = () __slots__ = ()
# Base commands:
CMD = 'command' CMD = 'command'
NAME = 'name' NAME = 'name'
POOL_SIZE = 'pool-size' POOL_SIZE = 'pool-size'
IS_LOCKED = 'is-locked'
LOCK = 'lock'
UNLOCK = 'unlock'
NUM_RUNNING = 'num-running' NUM_RUNNING = 'num-running'
NUM_CANCELLATIONS = 'num-cancellations'
NUM_ENDED = 'num-ended'
NUM_FINISHED = 'num-finished'
IS_FULL = 'is-full'
GET_GROUP_IDS = 'get-group-ids'
# Simple commands:
START = 'start' START = 'start'
STOP = 'stop' STOP = 'stop'
STOP_ALL = 'stop-all' STOP_ALL = 'stop-all'

View File

@ -178,23 +178,26 @@ class BaseTaskPool:
""" """
return self._enough_room.locked() return self._enough_room.locked()
def get_task_group_ids(self, group_name: str) -> Set[int]: def get_group_ids(self, *group_names: str) -> Set[int]:
""" """
Returns the set of IDs of all tasks in the specified group. Returns the set of IDs of all tasks in the specified groups.
Args: Args:
group_name: Must be a name of a task group that exists within the pool. *group_names: Each element must be a name of a task group that exists within the pool.
Returns: Returns:
Set of integers representing the task IDs belonging to the specified group. Set of integers representing the task IDs belonging to the specified groups.
Raises: Raises:
`InvalidGroupName` if no task group named `group_name` exists in the pool. `InvalidGroupName` if one of the specified `group_names` does not exist in the pool.
""" """
try: ids = set()
return set(self._task_groups[group_name]) for name in group_names:
except KeyError: try:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") ids.update(self._task_groups[name])
except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {name} exists in this pool.")
return ids
def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None, def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None,
ignore_lock: bool = False) -> None: ignore_lock: bool = False) -> None:

View File

@ -23,7 +23,7 @@ import logging
import json import json
from argparse import ArgumentError, HelpFormatter from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Type, Union, TYPE_CHECKING
from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
@ -108,19 +108,36 @@ class ControlSession:
These include commands mapping to the following pool methods: These include commands mapping to the following pool methods:
- __str__ - __str__
- pool_size (get/set property) - pool_size (get/set property)
- is_locked
- lock & unlock
- num_running - num_running
""" """
self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__)) cls: Type[BaseTaskPool] = self._pool.__class__
self._add_command(CMD.NAME, short_help=get_first_doc_line(cls.__str__))
self._add_command( self._add_command(
CMD.POOL_SIZE, CMD.POOL_SIZE,
short_help="Get/set the maximum number of tasks in the pool.", short_help="Get/set the maximum number of tasks in the pool.",
formatter_class=HelpFormatter formatter_class=HelpFormatter
).add_optional_num_argument( ).add_optional_num_argument(
default=None, default=None,
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " help=f"If passed a number: {get_first_doc_line(cls.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" f"If omitted: {get_first_doc_line(cls.pool_size.fget)}"
)
self._add_command(CMD.IS_LOCKED, short_help=get_first_doc_line(cls.is_locked.fget))
self._add_command(CMD.LOCK, short_help=get_first_doc_line(cls.lock))
self._add_command(CMD.UNLOCK, short_help=get_first_doc_line(cls.unlock))
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(cls.num_running.fget))
self._add_command(CMD.NUM_CANCELLATIONS, short_help=get_first_doc_line(cls.num_cancellations.fget))
self._add_command(CMD.NUM_ENDED, short_help=get_first_doc_line(cls.num_ended.fget))
self._add_command(CMD.NUM_FINISHED, short_help=get_first_doc_line(cls.num_finished.fget))
self._add_command(CMD.IS_FULL, short_help=get_first_doc_line(cls.is_full.fget))
self._add_command(
CMD.GET_GROUP_IDS, short_help=get_first_doc_line(cls.get_group_ids)
).add_argument(
'group_name',
nargs='*',
help="Must be a name of a task group that exists within the pool."
) )
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
def _add_simple_commands(self) -> None: def _add_simple_commands(self) -> None:
""" """

View File

@ -0,0 +1,85 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.group_register` module.
"""
from asyncio.locks import Lock
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import group_register
FOO, BAR = 'foo', 'bar'
class TaskGroupRegisterTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.reg = group_register.TaskGroupRegister()
def test_init(self):
ids = [FOO, BAR, 1, 2]
reg = group_register.TaskGroupRegister(*ids)
self.assertSetEqual(set(ids), reg._ids)
self.assertIsInstance(reg._lock, Lock)
def test___contains__(self):
self.reg._ids = {1, 2, 3}
for i in self.reg._ids:
self.assertTrue(i in self.reg)
self.assertFalse(4 in self.reg)
@patch.object(group_register, 'iter', return_value=FOO)
def test___iter__(self, mock_iter: MagicMock):
self.assertEqual(FOO, self.reg.__iter__())
mock_iter.assert_called_once_with(self.reg._ids)
def test___len__(self):
self.reg._ids = [1, 2, 3, 4]
self.assertEqual(4, len(self.reg))
def test_add(self):
self.assertSetEqual(set(), self.reg._ids)
self.assertIsNone(self.reg.add(123))
self.assertSetEqual({123}, self.reg._ids)
def test_discard(self):
self.reg._ids = {123}
self.assertIsNone(self.reg.discard(0))
self.assertIsNone(self.reg.discard(999))
self.assertIsNone(self.reg.discard(123))
self.assertSetEqual(set(), self.reg._ids)
async def test_acquire(self):
self.assertFalse(self.reg._lock.locked())
await self.reg.acquire()
self.assertTrue(self.reg._lock.locked())
def test_release(self):
self.reg._lock._locked = True
self.assertTrue(self.reg._lock.locked())
self.reg.release()
self.assertFalse(self.reg._lock.locked())
async def test_contextmanager(self):
self.assertFalse(self.reg._lock.locked())
async with self.reg as nothing:
self.assertIsNone(nothing)
self.assertTrue(self.reg._lock.locked())
self.assertFalse(self.reg._lock.locked())

View File

@ -163,12 +163,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
def test_is_full(self): def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
def test_get_task_group_ids(self): def test_get_group_ids(self):
group_name, ids = 'abcdef', [1, 2, 3] group_name, ids = 'abcdef', [1, 2, 3]
self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids)) self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name)) self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name))
with self.assertRaises(exceptions.InvalidGroupName): with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.get_task_group_ids('something else') self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self): async def test__check_start(self):
self.task_pool._closed = True self.task_pool._closed = True

View File

@ -0,0 +1,43 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.queue_context` module.
"""
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch
from asyncio_taskpool.queue_context import Queue
class QueueTestCase(IsolatedAsyncioTestCase):
def test_item_processed(self):
queue = Queue()
queue._unfinished_tasks = 1000
queue.item_processed()
self.assertEqual(999, queue._unfinished_tasks)
@patch.object(Queue, 'item_processed')
async def test_contextmanager(self, mock_item_processed: MagicMock):
queue = Queue()
item = 'foo'
queue.put_nowait(item)
async with queue as item_from_queue:
self.assertEqual(item, item_from_queue)
mock_item_processed.assert_not_called()
mock_item_processed.assert_called_once_with()