Add until_closed method to pools

This commit is contained in:
Daniil Fajnberg 2022-05-05 08:35:32 +02:00
parent a4ecf39157
commit 04bb398b90
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
4 changed files with 21 additions and 13 deletions

View File

@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
author = 'Daniil Fajnberg' author = 'Daniil Fajnberg'
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = '1.0.2' release = '1.1.0'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 1.0.3 version = 1.1.0
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -31,7 +31,7 @@ import logging
import warnings import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task, gather from asyncio.tasks import Task, create_task, gather
from contextlib import suppress from contextlib import suppress
from math import inf from math import inf
@ -72,7 +72,7 @@ class BaseTaskPool:
# Initialize flags; immutably set the name. # Initialize flags; immutably set the name.
self._locked: bool = False self._locked: bool = False
self._closed: bool = False self._closed: Event = Event()
self._name: str = name self._name: str = name
# The following three dictionaries are the actual containers of the tasks controlled by the pool. # The following three dictionaries are the actual containers of the tasks controlled by the pool.
@ -221,7 +221,7 @@ class BaseTaskPool:
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if function and not iscoroutinefunction(function): if function and not iscoroutinefunction(function):
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}") raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
if self._closed: if self._closed.is_set():
raise exceptions.PoolIsClosed("You must use another pool") raise exceptions.PoolIsClosed("You must use another pool")
if self._locked and not ignore_lock: if self._locked and not ignore_lock:
raise exceptions.PoolIsLocked("Cannot start new tasks") raise exceptions.PoolIsLocked("Cannot start new tasks")
@ -550,9 +550,10 @@ class BaseTaskPool:
self._tasks_ended.clear() self._tasks_ended.clear()
self._tasks_cancelled.clear() self._tasks_cancelled.clear()
self._tasks_running.clear() self._tasks_running.clear()
self._closed = True self._closed.set()
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server. async def until_closed(self) -> bool:
return await self._closed.wait()
class TaskPool(BaseTaskPool): class TaskPool(BaseTaskPool):

View File

@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
""" """
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type from typing import Type
@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(0, self.task_pool._num_started) self.assertEqual(0, self.task_pool._num_started)
self.assertFalse(self.task_pool._locked) self.assertFalse(self.task_pool._locked)
self.assertFalse(self.task_pool._closed) self.assertIsInstance(self.task_pool._closed, Event)
self.assertFalse(self.task_pool._closed.is_set())
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool.get_group_ids(group_name, 'something else') self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self): async def test__check_start(self):
self.task_pool._closed = True self.task_pool._closed.set()
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try: try:
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._check_start(awaitable=None, function=mock_coroutine) self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed): with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None) self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False self.task_pool._closed.clear()
self.task_pool._locked = True self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked): with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
@ -461,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertTrue(self.task_pool._closed) self.assertTrue(self.task_pool._closed.is_set())
async def test_until_closed(self):
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
output = await self.task_pool.until_closed()
self.assertEqual(FOO, output)
self.task_pool._closed.wait.assert_awaited_once_with()
class TaskPoolTestCase(CommonTestCase): class TaskPoolTestCase(CommonTestCase):