From b5eed608b53a5d205f306cade374428475a84552 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 6 Feb 2022 13:42:34 +0100 Subject: [PATCH] refactoring; new `flush` method to clear memory of ended tasks --- src/asyncio_taskpool/pool.py | 25 ++++++++++++++++--------- tests/test_pool.py | 23 ++++++++++------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index cf4299d..1344388 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -32,8 +32,9 @@ class BaseTaskPool: self._counter: int = 0 self._running: Dict[int, Task] = {} self._cancelled: Dict[int, Task] = {} - self._ending: int = 0 self._ended: Dict[int, Task] = {} + self._num_cancelled: int = 0 + self._num_ended: int = 0 self._idx: int = self._add_pool(self) self._name: str = name self._all_tasks_known_flag: Event = Event() @@ -73,7 +74,7 @@ class BaseTaskPool: Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment). At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running. """ - return len(self._cancelled) + return self._num_cancelled @property def num_ended(self) -> int: @@ -83,14 +84,14 @@ class BaseTaskPool: When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned, does it then actually end. """ - return len(self._ended) + return self._num_ended @property def num_finished(self) -> int: """ Returns the number of tasks in the pool that have actually finished running (without having been cancelled). """ - return self.num_ended - self.num_cancelled + self._ending + return self._num_ended - self._num_cancelled + len(self._cancelled) @property def is_full(self) -> bool: @@ -110,16 +111,16 @@ class BaseTaskPool: task = self._running.pop(task_id) assert task is not None self._cancelled[task_id] = task - self._ending += 1 + self._num_cancelled += 1 log.debug("Cancelled %s", self._task_name(task_id)) await _execute_function(custom_callback, args=(task_id, )) async def _end_task(self, task_id: int, custom_callback: EndCallbackT = None) -> None: task = self._running.pop(task_id, None) if task is None: - task = self._cancelled[task_id] - self._ending -= 1 + task = self._cancelled.pop(task_id) self._ended[task_id] = task + self._num_ended += 1 self._enough_room.release() log.info("Ended %s", self._task_name(task_id)) await _execute_function(custom_callback, args=(task_id, )) @@ -170,6 +171,11 @@ class BaseTaskPool: for task in self._running.values(): task.cancel(msg=msg) + async def flush(self, return_exceptions: bool = False): + results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions) + self._ended = self._cancelled = {} + return results + def close(self) -> None: self._open = False log.info("%s is closed!", str(self)) @@ -178,8 +184,9 @@ class BaseTaskPool: if self._open: raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") await self._all_tasks_known_flag.wait() - results = await gather(*self._running.values(), *self._ended.values(), return_exceptions=return_exceptions) - self._running = self._cancelled = self._ended = {} + results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), + return_exceptions=return_exceptions) + self._ended = self._cancelled = self._running = {} return results diff --git a/tests/test_pool.py b/tests/test_pool.py index ec2a4e4..96b4641 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,6 +1,6 @@ import asyncio from unittest import TestCase -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import PropertyMock, patch from asyncio_taskpool import pool @@ -45,8 +45,9 @@ class BaseTaskPoolTestCase(TestCase): self.assertEqual(0, self.task_pool._counter) self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) - self.assertEqual(0, self.task_pool._ending) self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) + self.assertEqual(0, self.task_pool._num_cancelled) + self.assertEqual(0, self.task_pool._num_ended) self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) @@ -83,22 +84,18 @@ class BaseTaskPoolTestCase(TestCase): self.assertEqual(3, self.task_pool.num_running) def test_num_cancelled(self): - self.task_pool._cancelled = ['foo', 'bar', 'baz'] + self.task_pool._num_cancelled = 33 self.assertEqual(3, self.task_pool.num_cancelled) def test_num_ended(self): - self.task_pool._ended = ['foo', 'bar', 'baz'] + self.task_pool._num_ended = 3 self.assertEqual(3, self.task_pool.num_ended) - @patch.object(pool.BaseTaskPool, 'num_ended', new_callable=PropertyMock) - @patch.object(pool.BaseTaskPool, 'num_cancelled', new_callable=PropertyMock) - def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock): - mock_num_cancelled.return_value = cancelled = 69 - mock_num_ended.return_value = ended = 420 - self.task_pool._ending = mock_ending = 2 - self.assertEqual(ended - cancelled + mock_ending, self.task_pool.num_finished) - mock_num_cancelled.assert_called_once_with() - mock_num_ended.assert_called_once_with() + def test_num_finished(self): + self.task_pool._num_cancelled = cancelled = 69 + self.task_pool._num_ended = ended = 420 + self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'} + self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished) def test_is_full(self): self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)