From 2f0b08edf017e8c405f0ccb1f43057037f678d15 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 6 Feb 2022 13:08:39 +0100 Subject: [PATCH] implemented working pool size limit; adjusted tests and examples; small renaming --- src/asyncio_taskpool/pool.py | 142 +++++++++++++++++---------------- src/asyncio_taskpool/server.py | 6 +- src/asyncio_taskpool/types.py | 2 +- tests/test_pool.py | 67 +++++----------- usage/USAGE.md | 4 +- usage/example_server.py | 2 +- 6 files changed, 98 insertions(+), 125 deletions(-) diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index c1e1d61..cf4299d 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -2,13 +2,13 @@ import logging from asyncio import gather from asyncio.coroutines import iscoroutinefunction from asyncio.exceptions import CancelledError -from asyncio.locks import Event +from asyncio.locks import Event, Semaphore from asyncio.tasks import Task, create_task from math import inf from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from . import exceptions -from .types import ArgsT, KwArgsT, CoroutineFunc, FinalCallbackT, CancelCallbackT +from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT log = logging.getLogger(__name__) @@ -26,7 +26,8 @@ class BaseTaskPool: def __init__(self, pool_size: int = inf, name: str = None) -> None: """Initializes the necessary internal attributes and adds the new pool to the general pools list.""" - self.pool_size: int = pool_size + self._enough_room: Semaphore = Semaphore() + self.pool_size = pool_size self._open: bool = True self._counter: int = 0 self._running: Dict[int, Task] = {} @@ -37,13 +38,22 @@ class BaseTaskPool: self._name: str = name self._all_tasks_known_flag: Event = Event() self._all_tasks_known_flag.set() - self._more_allowed_flag: Event = Event() - self._check_more_allowed() log.debug("%s initialized", str(self)) def __str__(self) -> str: return f'{self.__class__.__name__}-{self._name or self._idx}' + @property + def pool_size(self) -> int: + return self._pool_size + + @pool_size.setter + def pool_size(self, value: int) -> None: + if value < 0: + raise ValueError("Pool size can not be less than 0") + self._enough_room._value = value + self._pool_size = value + @property def is_open(self) -> bool: """Returns `True` if more the pool has not been closed yet.""" @@ -53,7 +63,7 @@ class BaseTaskPool: def num_running(self) -> int: """ Returns the number of tasks in the pool that are (at that moment) still running. - At the moment a task's final callback function is fired, it is no longer considered to be running. + At the moment a task's `end_callback` is fired, it is no longer considered to be running. """ return len(self._running) @@ -61,7 +71,7 @@ class BaseTaskPool: def num_cancelled(self) -> int: """ Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment). - At the moment a task's cancel callback function is fired, it is considered cancelled and no longer running. + At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running. """ return len(self._cancelled) @@ -69,9 +79,9 @@ class BaseTaskPool: def num_ended(self) -> int: """ Returns the number of tasks started through the pool that have stopped running (up until that moment). - At the moment a task's final callback function is fired, it is considered ended. - When a task is cancelled, it is not immediately considered ended; only after its cancel callback function has - returned, does it then actually end. + At the moment a task's `end_callback` is fired, it is considered ended. + When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned, + does it then actually end. """ return len(self._ended) @@ -88,18 +98,9 @@ class BaseTaskPool: Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size. When the pool is full, any call to start a new task within it will block. """ - return not self._more_allowed_flag.is_set() - - def _check_more_allowed(self) -> None: - """ - Sets or clears the internal event flag signalling whether or not the pool is full (i.e. whether more tasks can - be started), if the current state of the pool demands it. - """ - if self.is_full and self.num_running < self.pool_size: - self._more_allowed_flag.set() - elif not self.is_full and self.num_running >= self.pool_size: - self._more_allowed_flag.clear() + return self._enough_room.locked() + # TODO: Consider adding task group names def _task_name(self, task_id: int) -> str: """Returns a standardized name for a task with a specific `task_id`.""" return f'{self}_Task-{task_id}' @@ -110,20 +111,20 @@ class BaseTaskPool: assert task is not None self._cancelled[task_id] = task self._ending += 1 - await _execute_function(custom_callback, args=(task_id, )) log.debug("Cancelled %s", self._task_name(task_id)) + await _execute_function(custom_callback, args=(task_id, )) - async def _end_task(self, task_id: int, custom_callback: FinalCallbackT = None) -> None: + async def _end_task(self, task_id: int, custom_callback: EndCallbackT = None) -> None: task = self._running.pop(task_id, None) if task is None: task = self._cancelled[task_id] self._ending -= 1 self._ended[task_id] = task - await _execute_function(custom_callback, args=(task_id, )) - self._check_more_allowed() + self._enough_room.release() log.info("Ended %s", self._task_name(task_id)) + await _execute_function(custom_callback, args=(task_id, )) - async def _task_wrapper(self, awaitable: Awaitable, task_id: int, final_callback: FinalCallbackT = None, + async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Any: log.info("Started %s", self._task_name(task_id)) try: @@ -131,20 +132,23 @@ class BaseTaskPool: except CancelledError: await self._cancel_task(task_id, custom_callback=cancel_callback) finally: - await self._end_task(task_id, custom_callback=final_callback) + await self._end_task(task_id, custom_callback=end_callback) - def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None, - cancel_callback: CancelCallbackT = None) -> int: + async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None, + cancel_callback: CancelCallbackT = None) -> int: if not (self.is_open or ignore_closed): raise exceptions.PoolIsClosed("Cannot start new tasks") - # TODO: Implement this (and the dependent user-facing methods) as async to wait for room in the pool - task_id = self._counter - self._counter += 1 - self._running[task_id] = create_task( - self._task_wrapper(awaitable, task_id, final_callback, cancel_callback), - name=self._task_name(task_id) - ) - self._check_more_allowed() + await self._enough_room.acquire() + try: + task_id = self._counter + self._counter += 1 + self._running[task_id] = create_task( + self._task_wrapper(awaitable, task_id, end_callback, cancel_callback), + name=self._task_name(task_id) + ) + except Exception as e: + self._enough_room.release() + raise e return task_id def _cancel_one(self, task_id: int, msg: str = None) -> None: @@ -180,15 +184,15 @@ class BaseTaskPool: class TaskPool(BaseTaskPool): - def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int: + async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int: if kwargs is None: kwargs = {} - return self._start_task(func(*args, **kwargs), final_callback=final_callback, cancel_callback=cancel_callback) + return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback) - def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: - return tuple(self._apply_one(func, args, kwargs, final_callback, cancel_callback) for _ in range(num)) + async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: + return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)) @staticmethod def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]: @@ -204,54 +208,54 @@ class TaskPool(BaseTaskPool): return func(**arg) raise ValueError - def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: if self._all_tasks_known_flag.is_set(): self._all_tasks_known_flag.clear() args_iter = iter(args_iter) - def _start_next_coroutine() -> bool: + async def _start_next_coroutine() -> bool: cor = self._get_next_coroutine(func, args_iter, arg_stars) if cor is None: self._all_tasks_known_flag.set() return True - self._start_task(cor, ignore_closed=True, final_callback=_start_next, cancel_callback=cancel_callback) + await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback) return False async def _start_next(task_id: int) -> None: - await _execute_function(final_callback, args=(task_id, )) - _start_next_coroutine() + await _start_next_coroutine() + await _execute_function(end_callback, args=(task_id, )) for _ in range(num_tasks): - reached_end = _start_next_coroutine() + reached_end = await _start_next_coroutine() if reached_end: break - def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: - self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks, - final_callback=final_callback, cancel_callback=cancel_callback) + async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks, + end_callback=end_callback, cancel_callback=cancel_callback) - def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: - self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks, - final_callback=final_callback, cancel_callback=cancel_callback) + async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks, + end_callback=end_callback, cancel_callback=cancel_callback) - def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: - self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, - final_callback=final_callback, cancel_callback=cancel_callback) + async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, + end_callback=end_callback, cancel_callback=cancel_callback) class SimpleTaskPool(BaseTaskPool): def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, name: str = None) -> None: self._func: CoroutineFunc = func self._args: ArgsT = args self._kwargs: KwArgsT = kwargs if kwargs is not None else {} - self._final_callback: FinalCallbackT = final_callback + self._end_callback: EndCallbackT = end_callback self._cancel_callback: CancelCallbackT = cancel_callback super().__init__(name=name) @@ -263,12 +267,12 @@ class SimpleTaskPool(BaseTaskPool): def size(self) -> int: return self.num_running - def _start_one(self) -> int: - return self._start_task(self._func(*self._args, **self._kwargs), - final_callback=self._final_callback, cancel_callback=self._cancel_callback) + async def _start_one(self) -> int: + return await self._start_task(self._func(*self._args, **self._kwargs), + end_callback=self._end_callback, cancel_callback=self._cancel_callback) - def start(self, num: int = 1) -> List[int]: - return [self._start_one() for _ in range(num)] + async def start(self, num: int = 1) -> List[int]: + return [await self._start_one() for _ in range(num)] def stop(self, num: int = 1) -> List[int]: num = min(num, self.size) diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index 421e2b3..8db8e98 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -45,11 +45,11 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta self._server_kwargs = server_kwargs self._server: Optional[AbstractServer] = None - def _start_tasks(self, writer: StreamWriter, num: int = None) -> None: + async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None: if num is None: num = 1 log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num)) - writer.write(str(self._pool.start(num)).encode()) + writer.write(str(await self._pool.start(num)).encode()) def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: if num is None: @@ -78,7 +78,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta break cmd, arg = get_cmd_arg(msg) if cmd == constants.CMD_START: - self._start_tasks(writer, arg) + await self._start_tasks(writer, arg) elif cmd == constants.CMD_STOP: self._stop_tasks(writer, arg) elif cmd == constants.CMD_STOP_ALL: diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index 2caf9a6..649e58a 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -5,7 +5,7 @@ from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union ArgsT = Iterable[Any] KwArgsT = Mapping[str, Any] CoroutineFunc = Callable[[...], Awaitable[Any]] -FinalCallbackT = Callable +EndCallbackT = Callable CancelCallbackT = Callable ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] diff --git a/tests/test_pool.py b/tests/test_pool.py index 4b1e1d8..ec2a4e4 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,6 +1,6 @@ import asyncio from unittest import TestCase -from unittest.mock import MagicMock, PropertyMock, patch, call +from unittest.mock import MagicMock, PropertyMock, patch from asyncio_taskpool import pool @@ -14,10 +14,10 @@ class BaseTaskPoolTestCase(TestCase): # These three methods are called during initialization, so we mock them by default during setup self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool') - self._check_more_allowed_patcher = patch.object(pool.BaseTaskPool, '_check_more_allowed') + self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock) self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__') self.mock__add_pool = self._add_pool_patcher.start() - self.mock__check_more_allowed = self._check_more_allowed_patcher.start() + self.mock_pool_size = self.pool_size_patcher.start() self.mock___str__ = self.__str___patcher.start() self.mock__add_pool.return_value = self.mock_idx = 123 self.mock___str__.return_value = self.mock_str = 'foobar' @@ -29,7 +29,7 @@ class BaseTaskPoolTestCase(TestCase): def tearDown(self) -> None: setattr(pool.TaskPool, '_pools', self._pools) self._add_pool_patcher.stop() - self._check_more_allowed_patcher.stop() + self.pool_size_patcher.stop() self.__str___patcher.stop() def test__add_pool(self): @@ -40,7 +40,7 @@ class BaseTaskPoolTestCase(TestCase): self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) def test_init(self): - self.assertEqual(self.test_pool_size, self.task_pool.pool_size) + self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) self.assertTrue(self.task_pool._open) self.assertEqual(0, self.task_pool._counter) self.assertDictEqual(EMPTY_DICT, self.task_pool._running) @@ -51,9 +51,8 @@ class BaseTaskPoolTestCase(TestCase): self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) self.assertTrue(self.task_pool._all_tasks_known_flag.is_set()) - self.assertIsInstance(self.task_pool._more_allowed_flag, asyncio.locks.Event) self.mock__add_pool.assert_called_once_with(self.task_pool) - self.mock__check_more_allowed.assert_called_once_with() + self.mock_pool_size.assert_called_once_with(self.test_pool_size) self.mock___str__.assert_called_once_with() def test___str__(self): @@ -64,6 +63,17 @@ class BaseTaskPoolTestCase(TestCase): expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' self.assertEqual(expected_str, str(self.task_pool)) + def test_pool_size(self): + self.pool_size_patcher.stop() + self.task_pool._pool_size = self.test_pool_size + self.assertEqual(self.test_pool_size, self.task_pool.pool_size) + + with self.assertRaises(ValueError): + self.task_pool.pool_size = -1 + + self.task_pool.pool_size = new_size = 69 + self.assertEqual(new_size, self.task_pool._pool_size) + def test_is_open(self): self.task_pool._open = foo = 'foo' self.assertEqual(foo, self.task_pool.is_open) @@ -91,48 +101,7 @@ class BaseTaskPoolTestCase(TestCase): mock_num_ended.assert_called_once_with() def test_is_full(self): - self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full) - - @patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock) - @patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock) - def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock): - def reset_mocks(): - mock_is_full.reset_mock() - mock_num_running.reset_mock() - self._check_more_allowed_patcher.stop() - - # Just reaching limit, we expect flag to become unset: - mock_is_full.return_value = False - mock_num_running.return_value = 420 - self.task_pool._more_allowed_flag.clear() - self.task_pool._check_more_allowed() - self.assertFalse(self.task_pool._more_allowed_flag.is_set()) - mock_is_full.assert_has_calls([call(), call()]) - mock_num_running.assert_called_once_with() - reset_mocks() - - # Already at limit, we expect nothing to change: - mock_is_full.return_value = True - self.task_pool._check_more_allowed() - self.assertFalse(self.task_pool._more_allowed_flag.is_set()) - mock_is_full.assert_has_calls([call(), call()]) - mock_num_running.assert_called_once_with() - reset_mocks() - - # Just finished a task, we expect flag to become set: - mock_num_running.return_value = 419 - self.task_pool._check_more_allowed() - self.assertTrue(self.task_pool._more_allowed_flag.is_set()) - mock_is_full.assert_called_once_with() - mock_num_running.assert_called_once_with() - reset_mocks() - - # In this state we expect the flag to remain unchanged change: - mock_is_full.return_value = False - self.task_pool._check_more_allowed() - self.assertTrue(self.task_pool._more_allowed_flag.is_set()) - mock_is_full.assert_has_calls([call(), call()]) - mock_num_running.assert_called_once_with() + self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) def test__task_name(self): i = 123 diff --git a/usage/USAGE.md b/usage/USAGE.md index 8383329..3ac2ba7 100644 --- a/usage/USAGE.md +++ b/usage/USAGE.md @@ -33,9 +33,9 @@ async def work(n: int) -> None: async def main() -> None: pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet - pool.start(3) # launches work tasks 0, 1, and 2 + await pool.start(3) # launches work tasks 0, 1, and 2 await asyncio.sleep(1.5) # lets the tasks work for a bit - pool.start() # launches work task 3 + await pool.start() # launches work task 3 await asyncio.sleep(1.5) # lets the tasks work for a bit pool.stop(2) # cancels tasks 3 and 2 pool.close() # required for the last line diff --git a/usage/example_server.py b/usage/example_server.py index 5acc831..0478f43 100644 --- a/usage/example_server.py +++ b/usage/example_server.py @@ -44,7 +44,7 @@ async def main() -> None: for item in range(100): q.put_nowait(item) pool = SimpleTaskPool(worker, (q,)) # initializes the pool - pool.start(3) # launches three worker tasks + await pool.start(3) # launches three worker tasks control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() # We block until `.task_done()` has been called once by our workers for every item placed into the queue. await q.join()