generated from daniil-berg/boilerplate-py
	Compare commits
	
		
			3 Commits
		
	
	
		
			v0.5.0-lw
			...
			d05f84b2c3
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| d05f84b2c3 | |||
| 7c66604ad0 | |||
| 287906a218 | 
@@ -1,6 +1,6 @@
 | 
			
		||||
[metadata]
 | 
			
		||||
name = asyncio-taskpool
 | 
			
		||||
version = 0.5.0
 | 
			
		||||
version = 0.5.1
 | 
			
		||||
author = Daniil Fajnberg
 | 
			
		||||
author_email = mail@daniil.fajnberg.de
 | 
			
		||||
description = Dynamically manage pools of asyncio tasks
 | 
			
		||||
 
 | 
			
		||||
@@ -37,10 +37,21 @@ class CLIENT_INFO:
 | 
			
		||||
 | 
			
		||||
class CMD:
 | 
			
		||||
    __slots__ = ()
 | 
			
		||||
    # Base commands:
 | 
			
		||||
    CMD = 'command'
 | 
			
		||||
    NAME = 'name'
 | 
			
		||||
    POOL_SIZE = 'pool-size'
 | 
			
		||||
    IS_LOCKED = 'is-locked'
 | 
			
		||||
    LOCK = 'lock'
 | 
			
		||||
    UNLOCK = 'unlock'
 | 
			
		||||
    NUM_RUNNING = 'num-running'
 | 
			
		||||
    NUM_CANCELLATIONS = 'num-cancellations'
 | 
			
		||||
    NUM_ENDED = 'num-ended'
 | 
			
		||||
    NUM_FINISHED = 'num-finished'
 | 
			
		||||
    IS_FULL = 'is-full'
 | 
			
		||||
    GET_GROUP_IDS = 'get-group-ids'
 | 
			
		||||
 | 
			
		||||
    # Simple commands:
 | 
			
		||||
    START = 'start'
 | 
			
		||||
    STOP = 'stop'
 | 
			
		||||
    STOP_ALL = 'stop-all'
 | 
			
		||||
 
 | 
			
		||||
@@ -178,23 +178,26 @@ class BaseTaskPool:
 | 
			
		||||
        """
 | 
			
		||||
        return self._enough_room.locked()
 | 
			
		||||
 | 
			
		||||
    def get_task_group_ids(self, group_name: str) -> Set[int]:
 | 
			
		||||
    def get_group_ids(self, *group_names: str) -> Set[int]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the set of IDs of all tasks in the specified group.
 | 
			
		||||
        Returns the set of IDs of all tasks in the specified groups.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            group_name: Must be a name of a task group that exists within the pool.
 | 
			
		||||
            *group_names: Each element must be a name of a task group that exists within the pool.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Set of integers representing the task IDs belonging to the specified group.
 | 
			
		||||
            Set of integers representing the task IDs belonging to the specified groups.
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            `InvalidGroupName` if no task group named `group_name` exists in the pool.
 | 
			
		||||
            `InvalidGroupName` if one of the specified `group_names` does not exist in the pool.
 | 
			
		||||
        """
 | 
			
		||||
        ids = set()
 | 
			
		||||
        for name in group_names:
 | 
			
		||||
            try:
 | 
			
		||||
            return set(self._task_groups[group_name])
 | 
			
		||||
                ids.update(self._task_groups[name])
 | 
			
		||||
            except KeyError:
 | 
			
		||||
            raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
 | 
			
		||||
                raise exceptions.InvalidGroupName(f"No task group named {name} exists in this pool.")
 | 
			
		||||
        return ids
 | 
			
		||||
 | 
			
		||||
    def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None,
 | 
			
		||||
                     ignore_lock: bool = False) -> None:
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,7 @@ import logging
 | 
			
		||||
import json
 | 
			
		||||
from argparse import ArgumentError, HelpFormatter
 | 
			
		||||
from asyncio.streams import StreamReader, StreamWriter
 | 
			
		||||
from typing import Callable, Optional, Union, TYPE_CHECKING
 | 
			
		||||
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
 | 
			
		||||
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
 | 
			
		||||
@@ -108,19 +108,36 @@ class ControlSession:
 | 
			
		||||
        These include commands mapping to the following pool methods:
 | 
			
		||||
            - __str__
 | 
			
		||||
            - pool_size (get/set property)
 | 
			
		||||
            - is_locked
 | 
			
		||||
            - lock & unlock
 | 
			
		||||
            - num_running
 | 
			
		||||
        """
 | 
			
		||||
        self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
 | 
			
		||||
        cls: Type[BaseTaskPool] = self._pool.__class__
 | 
			
		||||
        self._add_command(CMD.NAME, short_help=get_first_doc_line(cls.__str__))
 | 
			
		||||
        self._add_command(
 | 
			
		||||
            CMD.POOL_SIZE, 
 | 
			
		||||
            short_help="Get/set the maximum number of tasks in the pool.", 
 | 
			
		||||
            formatter_class=HelpFormatter
 | 
			
		||||
        ).add_optional_num_argument(
 | 
			
		||||
            default=None,
 | 
			
		||||
            help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
 | 
			
		||||
                 f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
 | 
			
		||||
            help=f"If passed a number: {get_first_doc_line(cls.pool_size.fset)} "
 | 
			
		||||
                 f"If omitted: {get_first_doc_line(cls.pool_size.fget)}"
 | 
			
		||||
        )
 | 
			
		||||
        self._add_command(CMD.IS_LOCKED, short_help=get_first_doc_line(cls.is_locked.fget))
 | 
			
		||||
        self._add_command(CMD.LOCK, short_help=get_first_doc_line(cls.lock))
 | 
			
		||||
        self._add_command(CMD.UNLOCK, short_help=get_first_doc_line(cls.unlock))
 | 
			
		||||
        self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(cls.num_running.fget))
 | 
			
		||||
        self._add_command(CMD.NUM_CANCELLATIONS, short_help=get_first_doc_line(cls.num_cancellations.fget))
 | 
			
		||||
        self._add_command(CMD.NUM_ENDED, short_help=get_first_doc_line(cls.num_ended.fget))
 | 
			
		||||
        self._add_command(CMD.NUM_FINISHED, short_help=get_first_doc_line(cls.num_finished.fget))
 | 
			
		||||
        self._add_command(CMD.IS_FULL, short_help=get_first_doc_line(cls.is_full.fget))
 | 
			
		||||
        self._add_command(
 | 
			
		||||
            CMD.GET_GROUP_IDS, short_help=get_first_doc_line(cls.get_group_ids)
 | 
			
		||||
        ).add_argument(
 | 
			
		||||
            'group_name',
 | 
			
		||||
            nargs='*',
 | 
			
		||||
            help="Must be a name of a task group that exists within the pool."
 | 
			
		||||
        )
 | 
			
		||||
        self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
 | 
			
		||||
 | 
			
		||||
    def _add_simple_commands(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
@@ -227,11 +244,51 @@ class ControlSession:
 | 
			
		||||
            log.debug("%s requests setting pool size to %s", self._client_class_name, num)
 | 
			
		||||
            await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_is_locked(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `is_locked` property of any task pool class."""
 | 
			
		||||
        log.debug("%s checks locked status", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.__class__.is_locked.fget, self._pool)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_lock(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `lock` method of any task pool class."""
 | 
			
		||||
        log.debug("%s requests locking the pool", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.lock)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_unlock(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `unlock` method of any task pool class."""
 | 
			
		||||
        log.debug("%s requests unlocking the pool", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.unlock)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_num_running(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `num_running` property of any task pool class."""
 | 
			
		||||
        log.debug("%s requests number of running tasks", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_num_cancellations(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `num_cancellations` property of any task pool class."""
 | 
			
		||||
        log.debug("%s requests number of cancelled tasks", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.__class__.num_cancellations.fget, self._pool)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_num_ended(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `num_ended` property of any task pool class."""
 | 
			
		||||
        log.debug("%s requests number of ended tasks", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.__class__.num_ended.fget, self._pool)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_num_finished(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `num_finished` property of any task pool class."""
 | 
			
		||||
        log.debug("%s requests number of finished tasks", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.__class__.num_finished.fget, self._pool)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_is_full(self, **_kwargs) -> None:
 | 
			
		||||
        """Maps to the `is_full` property of any task pool class."""
 | 
			
		||||
        log.debug("%s checks full status", self._client_class_name)
 | 
			
		||||
        await self._write_function_output(self._pool.__class__.is_full.fget, self._pool)
 | 
			
		||||
 | 
			
		||||
    async def _cmd_get_group_ids(self, **kwargs) -> None:
 | 
			
		||||
        """Maps to the `get_group_ids` method of any task pool class."""
 | 
			
		||||
        log.debug("%s requests task ids for groups %s", self._client_class_name, kwargs['group_name'])
 | 
			
		||||
        await self._write_function_output(self._pool.get_group_ids, *kwargs['group_name'])
 | 
			
		||||
 | 
			
		||||
    async def _cmd_start(self, **kwargs) -> None:
 | 
			
		||||
        """Maps to the `start` method of the `SimpleTaskPool` class."""
 | 
			
		||||
        num = kwargs[NUM]
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										85
									
								
								tests/test_group_register.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								tests/test_group_register.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,85 @@
 | 
			
		||||
__author__ = "Daniil Fajnberg"
 | 
			
		||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
 | 
			
		||||
__license__ = """GNU LGPLv3.0
 | 
			
		||||
 | 
			
		||||
This file is part of asyncio-taskpool.
 | 
			
		||||
 | 
			
		||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
 | 
			
		||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
 | 
			
		||||
 | 
			
		||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; 
 | 
			
		||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
 | 
			
		||||
See the GNU Lesser General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. 
 | 
			
		||||
If not, see <https://www.gnu.org/licenses/>."""
 | 
			
		||||
 | 
			
		||||
__doc__ = """
 | 
			
		||||
Unittests for the `asyncio_taskpool.group_register` module.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from asyncio.locks import Lock
 | 
			
		||||
from unittest import IsolatedAsyncioTestCase
 | 
			
		||||
from unittest.mock import AsyncMock, MagicMock, patch
 | 
			
		||||
 | 
			
		||||
from asyncio_taskpool import group_register
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
FOO, BAR = 'foo', 'bar'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TaskGroupRegisterTestCase(IsolatedAsyncioTestCase):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        self.reg = group_register.TaskGroupRegister()
 | 
			
		||||
 | 
			
		||||
    def test_init(self):
 | 
			
		||||
        ids = [FOO, BAR, 1, 2]
 | 
			
		||||
        reg = group_register.TaskGroupRegister(*ids)
 | 
			
		||||
        self.assertSetEqual(set(ids), reg._ids)
 | 
			
		||||
        self.assertIsInstance(reg._lock, Lock)
 | 
			
		||||
 | 
			
		||||
    def test___contains__(self):
 | 
			
		||||
        self.reg._ids = {1, 2, 3}
 | 
			
		||||
        for i in self.reg._ids:
 | 
			
		||||
            self.assertTrue(i in self.reg)
 | 
			
		||||
        self.assertFalse(4 in self.reg)
 | 
			
		||||
 | 
			
		||||
    @patch.object(group_register, 'iter', return_value=FOO)
 | 
			
		||||
    def test___iter__(self, mock_iter: MagicMock):
 | 
			
		||||
        self.assertEqual(FOO, self.reg.__iter__())
 | 
			
		||||
        mock_iter.assert_called_once_with(self.reg._ids)
 | 
			
		||||
 | 
			
		||||
    def test___len__(self):
 | 
			
		||||
        self.reg._ids = [1, 2, 3, 4]
 | 
			
		||||
        self.assertEqual(4, len(self.reg))
 | 
			
		||||
 | 
			
		||||
    def test_add(self):
 | 
			
		||||
        self.assertSetEqual(set(), self.reg._ids)
 | 
			
		||||
        self.assertIsNone(self.reg.add(123))
 | 
			
		||||
        self.assertSetEqual({123}, self.reg._ids)
 | 
			
		||||
 | 
			
		||||
    def test_discard(self):
 | 
			
		||||
        self.reg._ids = {123}
 | 
			
		||||
        self.assertIsNone(self.reg.discard(0))
 | 
			
		||||
        self.assertIsNone(self.reg.discard(999))
 | 
			
		||||
        self.assertIsNone(self.reg.discard(123))
 | 
			
		||||
        self.assertSetEqual(set(), self.reg._ids)
 | 
			
		||||
 | 
			
		||||
    async def test_acquire(self):
 | 
			
		||||
        self.assertFalse(self.reg._lock.locked())
 | 
			
		||||
        await self.reg.acquire()
 | 
			
		||||
        self.assertTrue(self.reg._lock.locked())
 | 
			
		||||
 | 
			
		||||
    def test_release(self):
 | 
			
		||||
        self.reg._lock._locked = True
 | 
			
		||||
        self.assertTrue(self.reg._lock.locked())
 | 
			
		||||
        self.reg.release()
 | 
			
		||||
        self.assertFalse(self.reg._lock.locked())
 | 
			
		||||
 | 
			
		||||
    async def test_contextmanager(self):
 | 
			
		||||
        self.assertFalse(self.reg._lock.locked())
 | 
			
		||||
        async with self.reg as nothing:
 | 
			
		||||
            self.assertIsNone(nothing)
 | 
			
		||||
            self.assertTrue(self.reg._lock.locked())
 | 
			
		||||
        self.assertFalse(self.reg._lock.locked())
 | 
			
		||||
@@ -163,12 +163,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
 | 
			
		||||
    def test_is_full(self):
 | 
			
		||||
        self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
 | 
			
		||||
 | 
			
		||||
    def test_get_task_group_ids(self):
 | 
			
		||||
    def test_get_group_ids(self):
 | 
			
		||||
        group_name, ids = 'abcdef', [1, 2, 3]
 | 
			
		||||
        self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
 | 
			
		||||
        self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name))
 | 
			
		||||
        self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name))
 | 
			
		||||
        with self.assertRaises(exceptions.InvalidGroupName):
 | 
			
		||||
            self.task_pool.get_task_group_ids('something else')
 | 
			
		||||
            self.task_pool.get_group_ids(group_name, 'something else')
 | 
			
		||||
 | 
			
		||||
    async def test__check_start(self):
 | 
			
		||||
        self.task_pool._closed = True
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										43
									
								
								tests/test_queue_context.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								tests/test_queue_context.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
			
		||||
__author__ = "Daniil Fajnberg"
 | 
			
		||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
 | 
			
		||||
__license__ = """GNU LGPLv3.0
 | 
			
		||||
 | 
			
		||||
This file is part of asyncio-taskpool.
 | 
			
		||||
 | 
			
		||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
 | 
			
		||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
 | 
			
		||||
 | 
			
		||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; 
 | 
			
		||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
 | 
			
		||||
See the GNU Lesser General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. 
 | 
			
		||||
If not, see <https://www.gnu.org/licenses/>."""
 | 
			
		||||
 | 
			
		||||
__doc__ = """
 | 
			
		||||
Unittests for the `asyncio_taskpool.queue_context` module.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from unittest import IsolatedAsyncioTestCase
 | 
			
		||||
from unittest.mock import MagicMock, patch
 | 
			
		||||
 | 
			
		||||
from asyncio_taskpool.queue_context import Queue
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QueueTestCase(IsolatedAsyncioTestCase):
 | 
			
		||||
    def test_item_processed(self):
 | 
			
		||||
        queue = Queue()
 | 
			
		||||
        queue._unfinished_tasks = 1000
 | 
			
		||||
        queue.item_processed()
 | 
			
		||||
        self.assertEqual(999, queue._unfinished_tasks)
 | 
			
		||||
 | 
			
		||||
    @patch.object(Queue, 'item_processed')
 | 
			
		||||
    async def test_contextmanager(self, mock_item_processed: MagicMock):
 | 
			
		||||
        queue = Queue()
 | 
			
		||||
        item = 'foo'
 | 
			
		||||
        queue.put_nowait(item)
 | 
			
		||||
        async with queue as item_from_queue:
 | 
			
		||||
            self.assertEqual(item, item_from_queue)
 | 
			
		||||
            mock_item_processed.assert_not_called()
 | 
			
		||||
        mock_item_processed.assert_called_once_with()
 | 
			
		||||
		Reference in New Issue
	
	Block a user