diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 906abde..c1e1d61 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -15,20 +15,23 @@ log = logging.getLogger(__name__) class BaseTaskPool: + """The base class for task pools. Not intended to be used directly.""" _pools: List['BaseTaskPool'] = [] @classmethod def _add_pool(cls, pool: 'BaseTaskPool') -> int: + """Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list.""" cls._pools.append(pool) return len(cls._pools) - 1 - # TODO: Make use of `max_running` - def __init__(self, max_running: int = inf, name: str = None) -> None: - self._max_running: int = max_running + def __init__(self, pool_size: int = inf, name: str = None) -> None: + """Initializes the necessary internal attributes and adds the new pool to the general pools list.""" + self.pool_size: int = pool_size self._open: bool = True self._counter: int = 0 self._running: Dict[int, Task] = {} self._cancelled: Dict[int, Task] = {} + self._ending: int = 0 self._ended: Dict[int, Task] = {} self._idx: int = self._add_pool(self) self._name: str = name @@ -41,41 +44,64 @@ class BaseTaskPool: def __str__(self) -> str: return f'{self.__class__.__name__}-{self._name or self._idx}' - @property - def max_running(self) -> int: - return self._max_running - @property def is_open(self) -> bool: + """Returns `True` if more the pool has not been closed yet.""" return self._open @property def num_running(self) -> int: + """ + Returns the number of tasks in the pool that are (at that moment) still running. + At the moment a task's final callback function is fired, it is no longer considered to be running. + """ return len(self._running) @property def num_cancelled(self) -> int: + """ + Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment). + At the moment a task's cancel callback function is fired, it is considered cancelled and no longer running. + """ return len(self._cancelled) @property def num_ended(self) -> int: + """ + Returns the number of tasks started through the pool that have stopped running (up until that moment). + At the moment a task's final callback function is fired, it is considered ended. + When a task is cancelled, it is not immediately considered ended; only after its cancel callback function has + returned, does it then actually end. + """ return len(self._ended) @property def num_finished(self) -> int: - return self.num_ended - self.num_cancelled + """ + Returns the number of tasks in the pool that have actually finished running (without having been cancelled). + """ + return self.num_ended - self.num_cancelled + self._ending @property def is_full(self) -> bool: + """ + Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size. + When the pool is full, any call to start a new task within it will block. + """ return not self._more_allowed_flag.is_set() def _check_more_allowed(self) -> None: - if self.is_full and self.num_running < self.max_running: + """ + Sets or clears the internal event flag signalling whether or not the pool is full (i.e. whether more tasks can + be started), if the current state of the pool demands it. + """ + if self.is_full and self.num_running < self.pool_size: self._more_allowed_flag.set() - elif not self.is_full and self.num_running >= self.max_running: + elif not self.is_full and self.num_running >= self.pool_size: self._more_allowed_flag.clear() def _task_name(self, task_id: int) -> str: + """Returns a standardized name for a task with a specific `task_id`.""" return f'{self}_Task-{task_id}' async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None: @@ -83,6 +109,7 @@ class BaseTaskPool: task = self._running.pop(task_id) assert task is not None self._cancelled[task_id] = task + self._ending += 1 await _execute_function(custom_callback, args=(task_id, )) log.debug("Cancelled %s", self._task_name(task_id)) @@ -90,6 +117,7 @@ class BaseTaskPool: task = self._running.pop(task_id, None) if task is None: task = self._cancelled[task_id] + self._ending -= 1 self._ended[task_id] = task await _execute_function(custom_callback, args=(task_id, )) self._check_more_allowed() @@ -107,8 +135,9 @@ class BaseTaskPool: def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int: - if not (self._open or ignore_closed): + if not (self.is_open or ignore_closed): raise exceptions.PoolIsClosed("Cannot start new tasks") + # TODO: Implement this (and the dependent user-facing methods) as async to wait for room in the pool task_id = self._counter self._counter += 1 self._running[task_id] = create_task( diff --git a/tests/test_pool.py b/tests/test_pool.py index a281526..fcd6bc4 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -23,8 +23,8 @@ class BaseTaskPoolTestCase(TestCase): self.mock___str__.return_value = self.mock_str = 'foobar' # Test pool parameters: - self.mock_pool_params = {'max_running': 420, 'name': 'test123'} - self.task_pool = pool.BaseTaskPool(**self.mock_pool_params) + self.test_pool_size, self.test_pool_name = 420, 'test123' + self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name) def tearDown(self) -> None: setattr(pool.TaskPool, '_pools', self._pools) @@ -40,10 +40,11 @@ class BaseTaskPoolTestCase(TestCase): self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) def test_init(self): - for key, value in self.mock_pool_params.items(): - self.assertEqual(value, getattr(self.task_pool, f'_{key}')) + self.assertEqual(self.test_pool_size, self.task_pool.pool_size) + self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) + self.assertEqual(0, self.task_pool._ending) self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) @@ -55,16 +56,12 @@ class BaseTaskPoolTestCase(TestCase): def test___str__(self): self.__str___patcher.stop() - expected_str = f'{pool.BaseTaskPool.__name__}-{self.mock_pool_params["name"]}' + expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}' self.assertEqual(expected_str, str(self.task_pool)) setattr(self.task_pool, '_name', None) expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' self.assertEqual(expected_str, str(self.task_pool)) - def test_max_running(self): - self.task_pool._max_running = foo = 'foo' - self.assertEqual(foo, self.task_pool.max_running) - def test_is_open(self): self.task_pool._open = foo = 'foo' self.assertEqual(foo, self.task_pool.is_open) @@ -86,31 +83,30 @@ class BaseTaskPoolTestCase(TestCase): def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock): mock_num_cancelled.return_value = cancelled = 69 mock_num_ended.return_value = ended = 420 - self.assertEqual(ended - cancelled, self.task_pool.num_finished) + self.task_pool._ending = mock_ending = 2 + self.assertEqual(ended - cancelled + mock_ending, self.task_pool.num_finished) + mock_num_cancelled.assert_called_once_with() + mock_num_ended.assert_called_once_with() def test_is_full(self): self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full) - @patch.object(pool.BaseTaskPool, 'max_running', new_callable=PropertyMock) @patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock) @patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock) - def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock, - mock_max_running: MagicMock): + def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock): def reset_mocks(): mock_is_full.reset_mock() mock_num_running.reset_mock() - mock_max_running.reset_mock() self._check_more_allowed_patcher.stop() # Just reaching limit, we expect flag to become unset: mock_is_full.return_value = False - mock_num_running.return_value = mock_max_running.return_value = 420 + mock_num_running.return_value = 420 self.task_pool._more_allowed_flag.clear() self.task_pool._check_more_allowed() self.assertFalse(self.task_pool._more_allowed_flag.is_set()) mock_is_full.assert_has_calls([call(), call()]) mock_num_running.assert_called_once_with() - mock_max_running.assert_called_once_with() reset_mocks() # Already at limit, we expect nothing to change: @@ -119,7 +115,6 @@ class BaseTaskPoolTestCase(TestCase): self.assertFalse(self.task_pool._more_allowed_flag.is_set()) mock_is_full.assert_has_calls([call(), call()]) mock_num_running.assert_called_once_with() - mock_max_running.assert_called_once_with() reset_mocks() # Just finished a task, we expect flag to become set: @@ -128,7 +123,6 @@ class BaseTaskPoolTestCase(TestCase): self.assertTrue(self.task_pool._more_allowed_flag.is_set()) mock_is_full.assert_called_once_with() mock_num_running.assert_called_once_with() - mock_max_running.assert_called_once_with() reset_mocks() # In this state we expect the flag to remain unchanged change: @@ -137,7 +131,6 @@ class BaseTaskPoolTestCase(TestCase): self.assertTrue(self.task_pool._more_allowed_flag.is_set()) mock_is_full.assert_has_calls([call(), call()]) mock_num_running.assert_called_once_with() - mock_max_running.assert_called_once_with() def test__task_name(self): i = 123