From dae883446affcc5fa2e421563401a1bf4f84efbd Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Thu, 5 May 2022 08:35:32 +0200 Subject: [PATCH] Add `until_closed` method to pools --- docs/source/conf.py | 2 +- setup.cfg | 2 +- src/asyncio_taskpool/pool.py | 13 +++++++------ tests/test_pool.py | 17 ++++++++++++----- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 185d4b7..78832b3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg' author = 'Daniil Fajnberg' # The full version, including alpha/beta/rc tags -release = '1.0.2' +release = '1.1.0' # -- General configuration --------------------------------------------------- diff --git a/setup.cfg b/setup.cfg index a96b408..ecdf8f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 1.0.3 +version = 1.1.0 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index ed8d9c4..4b33ba7 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -31,7 +31,7 @@ import logging import warnings from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.exceptions import CancelledError -from asyncio.locks import Semaphore +from asyncio.locks import Event, Semaphore from asyncio.tasks import Task, create_task, gather from contextlib import suppress from math import inf @@ -72,7 +72,7 @@ class BaseTaskPool: # Initialize flags; immutably set the name. self._locked: bool = False - self._closed: bool = False + self._closed: Event = Event() self._name: str = name # The following three dictionaries are the actual containers of the tasks controlled by the pool. @@ -221,7 +221,7 @@ class BaseTaskPool: raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") if function and not iscoroutinefunction(function): raise exceptions.NotCoroutine(f"Not a coroutine function: {function}") - if self._closed: + if self._closed.is_set(): raise exceptions.PoolIsClosed("You must use another pool") if self._locked and not ignore_lock: raise exceptions.PoolIsLocked("Cannot start new tasks") @@ -550,9 +550,10 @@ class BaseTaskPool: self._tasks_ended.clear() self._tasks_cancelled.clear() self._tasks_running.clear() - self._closed = True - # TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will - # await it to allow blocking until a closing command comes from a server. + self._closed.set() + + async def until_closed(self) -> bool: + return await self._closed.wait() class TaskPool(BaseTaskPool): diff --git a/tests/test_pool.py b/tests/test_pool.py index dbcc1af..ce5ac08 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module. """ from asyncio.exceptions import CancelledError -from asyncio.locks import Semaphore +from asyncio.locks import Event, Semaphore from unittest import IsolatedAsyncioTestCase from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from typing import Type @@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertEqual(0, self.task_pool._num_started) self.assertFalse(self.task_pool._locked) - self.assertFalse(self.task_pool._closed) + self.assertIsInstance(self.task_pool._closed, Event) + self.assertFalse(self.task_pool._closed.is_set()) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) @@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase): self.task_pool.get_group_ids(group_name, 'something else') async def test__check_start(self): - self.task_pool._closed = True + self.task_pool._closed.set() mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() try: with self.assertRaises(AssertionError): @@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase): self.task_pool._check_start(awaitable=None, function=mock_coroutine) with self.assertRaises(exceptions.PoolIsClosed): self.task_pool._check_start(awaitable=mock_coroutine, function=None) - self.task_pool._closed = False + self.task_pool._closed.clear() self.task_pool._locked = True with self.assertRaises(exceptions.PoolIsLocked): self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) @@ -461,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) - self.assertTrue(self.task_pool._closed) + self.assertTrue(self.task_pool._closed.is_set()) + + async def test_until_closed(self): + self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO)) + output = await self.task_pool.until_closed() + self.assertEqual(FOO, output) + self.task_pool._closed.wait.assert_awaited_once_with() class TaskPoolTestCase(CommonTestCase):