docstrings, refactorings, test adjustments

This commit is contained in:
Daniil Fajnberg 2022-02-05 22:26:02 +01:00
parent 3a27040a54
commit 9ec5359fd6
2 changed files with 52 additions and 30 deletions

View File

@ -15,20 +15,23 @@ log = logging.getLogger(__name__)
class BaseTaskPool: class BaseTaskPool:
"""The base class for task pools. Not intended to be used directly."""
_pools: List['BaseTaskPool'] = [] _pools: List['BaseTaskPool'] = []
@classmethod @classmethod
def _add_pool(cls, pool: 'BaseTaskPool') -> int: def _add_pool(cls, pool: 'BaseTaskPool') -> int:
"""Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list."""
cls._pools.append(pool) cls._pools.append(pool)
return len(cls._pools) - 1 return len(cls._pools) - 1
# TODO: Make use of `max_running` def __init__(self, pool_size: int = inf, name: str = None) -> None:
def __init__(self, max_running: int = inf, name: str = None) -> None: """Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self._max_running: int = max_running self.pool_size: int = pool_size
self._open: bool = True self._open: bool = True
self._counter: int = 0 self._counter: int = 0
self._running: Dict[int, Task] = {} self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {} self._cancelled: Dict[int, Task] = {}
self._ending: int = 0
self._ended: Dict[int, Task] = {} self._ended: Dict[int, Task] = {}
self._idx: int = self._add_pool(self) self._idx: int = self._add_pool(self)
self._name: str = name self._name: str = name
@ -41,41 +44,64 @@ class BaseTaskPool:
def __str__(self) -> str: def __str__(self) -> str:
return f'{self.__class__.__name__}-{self._name or self._idx}' return f'{self.__class__.__name__}-{self._name or self._idx}'
@property
def max_running(self) -> int:
return self._max_running
@property @property
def is_open(self) -> bool: def is_open(self) -> bool:
"""Returns `True` if more the pool has not been closed yet."""
return self._open return self._open
@property @property
def num_running(self) -> int: def num_running(self) -> int:
"""
Returns the number of tasks in the pool that are (at that moment) still running.
At the moment a task's final callback function is fired, it is no longer considered to be running.
"""
return len(self._running) return len(self._running)
@property @property
def num_cancelled(self) -> int: def num_cancelled(self) -> int:
"""
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
At the moment a task's cancel callback function is fired, it is considered cancelled and no longer running.
"""
return len(self._cancelled) return len(self._cancelled)
@property @property
def num_ended(self) -> int: def num_ended(self) -> int:
"""
Returns the number of tasks started through the pool that have stopped running (up until that moment).
At the moment a task's final callback function is fired, it is considered ended.
When a task is cancelled, it is not immediately considered ended; only after its cancel callback function has
returned, does it then actually end.
"""
return len(self._ended) return len(self._ended)
@property @property
def num_finished(self) -> int: def num_finished(self) -> int:
return self.num_ended - self.num_cancelled """
Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
"""
return self.num_ended - self.num_cancelled + self._ending
@property @property
def is_full(self) -> bool: def is_full(self) -> bool:
"""
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
When the pool is full, any call to start a new task within it will block.
"""
return not self._more_allowed_flag.is_set() return not self._more_allowed_flag.is_set()
def _check_more_allowed(self) -> None: def _check_more_allowed(self) -> None:
if self.is_full and self.num_running < self.max_running: """
Sets or clears the internal event flag signalling whether or not the pool is full (i.e. whether more tasks can
be started), if the current state of the pool demands it.
"""
if self.is_full and self.num_running < self.pool_size:
self._more_allowed_flag.set() self._more_allowed_flag.set()
elif not self.is_full and self.num_running >= self.max_running: elif not self.is_full and self.num_running >= self.pool_size:
self._more_allowed_flag.clear() self._more_allowed_flag.clear()
def _task_name(self, task_id: int) -> str: def _task_name(self, task_id: int) -> str:
"""Returns a standardized name for a task with a specific `task_id`."""
return f'{self}_Task-{task_id}' return f'{self}_Task-{task_id}'
async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None: async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
@ -83,6 +109,7 @@ class BaseTaskPool:
task = self._running.pop(task_id) task = self._running.pop(task_id)
assert task is not None assert task is not None
self._cancelled[task_id] = task self._cancelled[task_id] = task
self._ending += 1
await _execute_function(custom_callback, args=(task_id, )) await _execute_function(custom_callback, args=(task_id, ))
log.debug("Cancelled %s", self._task_name(task_id)) log.debug("Cancelled %s", self._task_name(task_id))
@ -90,6 +117,7 @@ class BaseTaskPool:
task = self._running.pop(task_id, None) task = self._running.pop(task_id, None)
if task is None: if task is None:
task = self._cancelled[task_id] task = self._cancelled[task_id]
self._ending -= 1
self._ended[task_id] = task self._ended[task_id] = task
await _execute_function(custom_callback, args=(task_id, )) await _execute_function(custom_callback, args=(task_id, ))
self._check_more_allowed() self._check_more_allowed()
@ -107,8 +135,9 @@ class BaseTaskPool:
def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None, def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int: cancel_callback: CancelCallbackT = None) -> int:
if not (self._open or ignore_closed): if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsClosed("Cannot start new tasks")
# TODO: Implement this (and the dependent user-facing methods) as async to wait for room in the pool
task_id = self._counter task_id = self._counter
self._counter += 1 self._counter += 1
self._running[task_id] = create_task( self._running[task_id] = create_task(

View File

@ -23,8 +23,8 @@ class BaseTaskPoolTestCase(TestCase):
self.mock___str__.return_value = self.mock_str = 'foobar' self.mock___str__.return_value = self.mock_str = 'foobar'
# Test pool parameters: # Test pool parameters:
self.mock_pool_params = {'max_running': 420, 'name': 'test123'} self.test_pool_size, self.test_pool_name = 420, 'test123'
self.task_pool = pool.BaseTaskPool(**self.mock_pool_params) self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
def tearDown(self) -> None: def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools) setattr(pool.TaskPool, '_pools', self._pools)
@ -40,10 +40,11 @@ class BaseTaskPoolTestCase(TestCase):
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
def test_init(self): def test_init(self):
for key, value in self.mock_pool_params.items(): self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
self.assertEqual(value, getattr(self.task_pool, f'_{key}')) self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
self.assertEqual(0, self.task_pool._ending)
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
@ -55,16 +56,12 @@ class BaseTaskPoolTestCase(TestCase):
def test___str__(self): def test___str__(self):
self.__str___patcher.stop() self.__str___patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.mock_pool_params["name"]}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
setattr(self.task_pool, '_name', None) setattr(self.task_pool, '_name', None)
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
def test_max_running(self):
self.task_pool._max_running = foo = 'foo'
self.assertEqual(foo, self.task_pool.max_running)
def test_is_open(self): def test_is_open(self):
self.task_pool._open = foo = 'foo' self.task_pool._open = foo = 'foo'
self.assertEqual(foo, self.task_pool.is_open) self.assertEqual(foo, self.task_pool.is_open)
@ -86,31 +83,30 @@ class BaseTaskPoolTestCase(TestCase):
def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock): def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock):
mock_num_cancelled.return_value = cancelled = 69 mock_num_cancelled.return_value = cancelled = 69
mock_num_ended.return_value = ended = 420 mock_num_ended.return_value = ended = 420
self.assertEqual(ended - cancelled, self.task_pool.num_finished) self.task_pool._ending = mock_ending = 2
self.assertEqual(ended - cancelled + mock_ending, self.task_pool.num_finished)
mock_num_cancelled.assert_called_once_with()
mock_num_ended.assert_called_once_with()
def test_is_full(self): def test_is_full(self):
self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full) self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full)
@patch.object(pool.BaseTaskPool, 'max_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock) @patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock) @patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock)
def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock, def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock):
mock_max_running: MagicMock):
def reset_mocks(): def reset_mocks():
mock_is_full.reset_mock() mock_is_full.reset_mock()
mock_num_running.reset_mock() mock_num_running.reset_mock()
mock_max_running.reset_mock()
self._check_more_allowed_patcher.stop() self._check_more_allowed_patcher.stop()
# Just reaching limit, we expect flag to become unset: # Just reaching limit, we expect flag to become unset:
mock_is_full.return_value = False mock_is_full.return_value = False
mock_num_running.return_value = mock_max_running.return_value = 420 mock_num_running.return_value = 420
self.task_pool._more_allowed_flag.clear() self.task_pool._more_allowed_flag.clear()
self.task_pool._check_more_allowed() self.task_pool._check_more_allowed()
self.assertFalse(self.task_pool._more_allowed_flag.is_set()) self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()]) mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with() mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks() reset_mocks()
# Already at limit, we expect nothing to change: # Already at limit, we expect nothing to change:
@ -119,7 +115,6 @@ class BaseTaskPoolTestCase(TestCase):
self.assertFalse(self.task_pool._more_allowed_flag.is_set()) self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()]) mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with() mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks() reset_mocks()
# Just finished a task, we expect flag to become set: # Just finished a task, we expect flag to become set:
@ -128,7 +123,6 @@ class BaseTaskPoolTestCase(TestCase):
self.assertTrue(self.task_pool._more_allowed_flag.is_set()) self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_called_once_with() mock_is_full.assert_called_once_with()
mock_num_running.assert_called_once_with() mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks() reset_mocks()
# In this state we expect the flag to remain unchanged change: # In this state we expect the flag to remain unchanged change:
@ -137,7 +131,6 @@ class BaseTaskPoolTestCase(TestCase):
self.assertTrue(self.task_pool._more_allowed_flag.is_set()) self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()]) mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with() mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
def test__task_name(self): def test__task_name(self):
i = 123 i = 123