refactoring; new flush method to clear memory of ended tasks

This commit is contained in:
Daniil Fajnberg 2022-02-06 13:42:34 +01:00
parent 2f0b08edf0
commit b5eed608b5
2 changed files with 26 additions and 22 deletions

View File

@ -32,8 +32,9 @@ class BaseTaskPool:
self._counter: int = 0 self._counter: int = 0
self._running: Dict[int, Task] = {} self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {} self._cancelled: Dict[int, Task] = {}
self._ending: int = 0
self._ended: Dict[int, Task] = {} self._ended: Dict[int, Task] = {}
self._num_cancelled: int = 0
self._num_ended: int = 0
self._idx: int = self._add_pool(self) self._idx: int = self._add_pool(self)
self._name: str = name self._name: str = name
self._all_tasks_known_flag: Event = Event() self._all_tasks_known_flag: Event = Event()
@ -73,7 +74,7 @@ class BaseTaskPool:
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment). Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running. At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
""" """
return len(self._cancelled) return self._num_cancelled
@property @property
def num_ended(self) -> int: def num_ended(self) -> int:
@ -83,14 +84,14 @@ class BaseTaskPool:
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned, When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
does it then actually end. does it then actually end.
""" """
return len(self._ended) return self._num_ended
@property @property
def num_finished(self) -> int: def num_finished(self) -> int:
""" """
Returns the number of tasks in the pool that have actually finished running (without having been cancelled). Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
""" """
return self.num_ended - self.num_cancelled + self._ending return self._num_ended - self._num_cancelled + len(self._cancelled)
@property @property
def is_full(self) -> bool: def is_full(self) -> bool:
@ -110,16 +111,16 @@ class BaseTaskPool:
task = self._running.pop(task_id) task = self._running.pop(task_id)
assert task is not None assert task is not None
self._cancelled[task_id] = task self._cancelled[task_id] = task
self._ending += 1 self._num_cancelled += 1
log.debug("Cancelled %s", self._task_name(task_id)) log.debug("Cancelled %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, )) await _execute_function(custom_callback, args=(task_id, ))
async def _end_task(self, task_id: int, custom_callback: EndCallbackT = None) -> None: async def _end_task(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
task = self._running.pop(task_id, None) task = self._running.pop(task_id, None)
if task is None: if task is None:
task = self._cancelled[task_id] task = self._cancelled.pop(task_id)
self._ending -= 1
self._ended[task_id] = task self._ended[task_id] = task
self._num_ended += 1
self._enough_room.release() self._enough_room.release()
log.info("Ended %s", self._task_name(task_id)) log.info("Ended %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, )) await _execute_function(custom_callback, args=(task_id, ))
@ -170,6 +171,11 @@ class BaseTaskPool:
for task in self._running.values(): for task in self._running.values():
task.cancel(msg=msg) task.cancel(msg=msg)
async def flush(self, return_exceptions: bool = False):
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
self._ended = self._cancelled = {}
return results
def close(self) -> None: def close(self) -> None:
self._open = False self._open = False
log.info("%s is closed!", str(self)) log.info("%s is closed!", str(self))
@ -178,8 +184,9 @@ class BaseTaskPool:
if self._open: if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
await self._all_tasks_known_flag.wait() await self._all_tasks_known_flag.wait()
results = await gather(*self._running.values(), *self._ended.values(), return_exceptions=return_exceptions) results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
self._running = self._cancelled = self._ended = {} return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {}
return results return results

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import PropertyMock, patch
from asyncio_taskpool import pool from asyncio_taskpool import pool
@ -45,8 +45,9 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(0, self.task_pool._counter) self.assertEqual(0, self.task_pool._counter)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
self.assertEqual(0, self.task_pool._ending)
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
@ -83,22 +84,18 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(3, self.task_pool.num_running) self.assertEqual(3, self.task_pool.num_running)
def test_num_cancelled(self): def test_num_cancelled(self):
self.task_pool._cancelled = ['foo', 'bar', 'baz'] self.task_pool._num_cancelled = 33
self.assertEqual(3, self.task_pool.num_cancelled) self.assertEqual(3, self.task_pool.num_cancelled)
def test_num_ended(self): def test_num_ended(self):
self.task_pool._ended = ['foo', 'bar', 'baz'] self.task_pool._num_ended = 3
self.assertEqual(3, self.task_pool.num_ended) self.assertEqual(3, self.task_pool.num_ended)
@patch.object(pool.BaseTaskPool, 'num_ended', new_callable=PropertyMock) def test_num_finished(self):
@patch.object(pool.BaseTaskPool, 'num_cancelled', new_callable=PropertyMock) self.task_pool._num_cancelled = cancelled = 69
def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock): self.task_pool._num_ended = ended = 420
mock_num_cancelled.return_value = cancelled = 69 self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
mock_num_ended.return_value = ended = 420 self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)
self.task_pool._ending = mock_ending = 2
self.assertEqual(ended - cancelled + mock_ending, self.task_pool.num_finished)
mock_num_cancelled.assert_called_once_with()
mock_num_ended.assert_called_once_with()
def test_is_full(self): def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)