docstrings, refactorings, test adjustments

This commit is contained in:
Daniil Fajnberg 2022-02-05 22:26:02 +01:00
parent 3a27040a54
commit 9ec5359fd6
2 changed files with 52 additions and 30 deletions

View File

@ -15,20 +15,23 @@ log = logging.getLogger(__name__)
class BaseTaskPool:
"""The base class for task pools. Not intended to be used directly."""
_pools: List['BaseTaskPool'] = []
@classmethod
def _add_pool(cls, pool: 'BaseTaskPool') -> int:
"""Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list."""
cls._pools.append(pool)
return len(cls._pools) - 1
# TODO: Make use of `max_running`
def __init__(self, max_running: int = inf, name: str = None) -> None:
self._max_running: int = max_running
def __init__(self, pool_size: int = inf, name: str = None) -> None:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self.pool_size: int = pool_size
self._open: bool = True
self._counter: int = 0
self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {}
self._ending: int = 0
self._ended: Dict[int, Task] = {}
self._idx: int = self._add_pool(self)
self._name: str = name
@ -41,41 +44,64 @@ class BaseTaskPool:
def __str__(self) -> str:
return f'{self.__class__.__name__}-{self._name or self._idx}'
@property
def max_running(self) -> int:
return self._max_running
@property
def is_open(self) -> bool:
"""Returns `True` if more the pool has not been closed yet."""
return self._open
@property
def num_running(self) -> int:
"""
Returns the number of tasks in the pool that are (at that moment) still running.
At the moment a task's final callback function is fired, it is no longer considered to be running.
"""
return len(self._running)
@property
def num_cancelled(self) -> int:
"""
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
At the moment a task's cancel callback function is fired, it is considered cancelled and no longer running.
"""
return len(self._cancelled)
@property
def num_ended(self) -> int:
"""
Returns the number of tasks started through the pool that have stopped running (up until that moment).
At the moment a task's final callback function is fired, it is considered ended.
When a task is cancelled, it is not immediately considered ended; only after its cancel callback function has
returned, does it then actually end.
"""
return len(self._ended)
@property
def num_finished(self) -> int:
return self.num_ended - self.num_cancelled
"""
Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
"""
return self.num_ended - self.num_cancelled + self._ending
@property
def is_full(self) -> bool:
"""
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
When the pool is full, any call to start a new task within it will block.
"""
return not self._more_allowed_flag.is_set()
def _check_more_allowed(self) -> None:
if self.is_full and self.num_running < self.max_running:
"""
Sets or clears the internal event flag signalling whether or not the pool is full (i.e. whether more tasks can
be started), if the current state of the pool demands it.
"""
if self.is_full and self.num_running < self.pool_size:
self._more_allowed_flag.set()
elif not self.is_full and self.num_running >= self.max_running:
elif not self.is_full and self.num_running >= self.pool_size:
self._more_allowed_flag.clear()
def _task_name(self, task_id: int) -> str:
"""Returns a standardized name for a task with a specific `task_id`."""
return f'{self}_Task-{task_id}'
async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
@ -83,6 +109,7 @@ class BaseTaskPool:
task = self._running.pop(task_id)
assert task is not None
self._cancelled[task_id] = task
self._ending += 1
await _execute_function(custom_callback, args=(task_id, ))
log.debug("Cancelled %s", self._task_name(task_id))
@ -90,6 +117,7 @@ class BaseTaskPool:
task = self._running.pop(task_id, None)
if task is None:
task = self._cancelled[task_id]
self._ending -= 1
self._ended[task_id] = task
await _execute_function(custom_callback, args=(task_id, ))
self._check_more_allowed()
@ -107,8 +135,9 @@ class BaseTaskPool:
def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int:
if not (self._open or ignore_closed):
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
# TODO: Implement this (and the dependent user-facing methods) as async to wait for room in the pool
task_id = self._counter
self._counter += 1
self._running[task_id] = create_task(

View File

@ -23,8 +23,8 @@ class BaseTaskPoolTestCase(TestCase):
self.mock___str__.return_value = self.mock_str = 'foobar'
# Test pool parameters:
self.mock_pool_params = {'max_running': 420, 'name': 'test123'}
self.task_pool = pool.BaseTaskPool(**self.mock_pool_params)
self.test_pool_size, self.test_pool_name = 420, 'test123'
self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools)
@ -40,10 +40,11 @@ class BaseTaskPoolTestCase(TestCase):
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
def test_init(self):
for key, value in self.mock_pool_params.items():
self.assertEqual(value, getattr(self.task_pool, f'_{key}'))
self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
self.assertEqual(0, self.task_pool._ending)
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
@ -55,16 +56,12 @@ class BaseTaskPoolTestCase(TestCase):
def test___str__(self):
self.__str___patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.mock_pool_params["name"]}'
expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
self.assertEqual(expected_str, str(self.task_pool))
setattr(self.task_pool, '_name', None)
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool))
def test_max_running(self):
self.task_pool._max_running = foo = 'foo'
self.assertEqual(foo, self.task_pool.max_running)
def test_is_open(self):
self.task_pool._open = foo = 'foo'
self.assertEqual(foo, self.task_pool.is_open)
@ -86,31 +83,30 @@ class BaseTaskPoolTestCase(TestCase):
def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock):
mock_num_cancelled.return_value = cancelled = 69
mock_num_ended.return_value = ended = 420
self.assertEqual(ended - cancelled, self.task_pool.num_finished)
self.task_pool._ending = mock_ending = 2
self.assertEqual(ended - cancelled + mock_ending, self.task_pool.num_finished)
mock_num_cancelled.assert_called_once_with()
mock_num_ended.assert_called_once_with()
def test_is_full(self):
self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full)
@patch.object(pool.BaseTaskPool, 'max_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock)
def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock,
mock_max_running: MagicMock):
def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock):
def reset_mocks():
mock_is_full.reset_mock()
mock_num_running.reset_mock()
mock_max_running.reset_mock()
self._check_more_allowed_patcher.stop()
# Just reaching limit, we expect flag to become unset:
mock_is_full.return_value = False
mock_num_running.return_value = mock_max_running.return_value = 420
mock_num_running.return_value = 420
self.task_pool._more_allowed_flag.clear()
self.task_pool._check_more_allowed()
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks()
# Already at limit, we expect nothing to change:
@ -119,7 +115,6 @@ class BaseTaskPoolTestCase(TestCase):
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks()
# Just finished a task, we expect flag to become set:
@ -128,7 +123,6 @@ class BaseTaskPoolTestCase(TestCase):
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_called_once_with()
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks()
# In this state we expect the flag to remain unchanged change:
@ -137,7 +131,6 @@ class BaseTaskPoolTestCase(TestCase):
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
def test__task_name(self):
i = 123