From bc9d2f243e953b0db19ca0b45a96e3f7c677f3ff Mon Sep 17 00:00:00 2001
From: Daniil Fajnberg
Date: Tue, 8 Feb 2022 23:09:33 +0100
Subject: [PATCH] renamed "closing" a pool to "locking" it

---
 setup.cfg                          |  2 +-
 src/asyncio_taskpool/exceptions.py |  4 +-
 src/asyncio_taskpool/pool.py       | 57 +++++++++++--------
 tests/test_pool.py                 | 91 +++++++++++++++---------------
 usage/USAGE.md                     | 14 ++---
 usage/example_server.py            |  4 +-
 6 files changed, 88 insertions(+), 84 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 8d92626..5e2d782 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = asyncio-taskpool
-version = 0.1.7
+version = 0.2.0
 author = Daniil Fajnberg
 author_email = mail@daniil.fajnberg.de
 description = Dynamically manage pools of asyncio tasks
diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py
index 8176e4e..b717de5 100644
--- a/src/asyncio_taskpool/exceptions.py
+++ b/src/asyncio_taskpool/exceptions.py
@@ -2,7 +2,7 @@ class PoolException(Exception):
     pass
 
 
-class PoolIsClosed(PoolException):
+class PoolIsLocked(PoolException):
     pass
 
 
@@ -22,7 +22,7 @@ class InvalidTaskID(PoolException):
     pass
 
 
-class PoolStillOpen(PoolException):
+class PoolStillUnlocked(PoolException):
     pass
 
 
diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py
index 5a08c68..489b0f7 100644
--- a/src/asyncio_taskpool/pool.py
+++ b/src/asyncio_taskpool/pool.py
@@ -31,7 +31,7 @@ class BaseTaskPool:
         """Initializes the necessary internal attributes and adds the new pool to the general pools list."""
         self._enough_room: Semaphore = Semaphore()
         self.pool_size = pool_size
-        self._open: bool = True
+        self._locked: bool = False
         self._counter: int = 0
         self._running: Dict[int, Task] = {}
         self._cancelled: Dict[int, Task] = {}
@@ -71,9 +71,21 @@ class BaseTaskPool:
         self._pool_size = value
 
     @property
-    def is_open(self) -> bool:
-        """Returns `True` if more the pool has not been closed yet."""
-        return self._open
+    def is_locked(self) -> bool:
+        """Returns `True` if the pool has been locked (see below)."""
+        return self._locked
+
+    def lock(self) -> None:
+        """Disallows any more tasks to be started in the pool."""
+        if not self._locked:
+            self._locked = True
+            log.info("%s is locked!", str(self))
+
+    def unlock(self) -> None:
+        """Allows new tasks to be started in the pool."""
+        if self._locked:
+            self._locked = False
+            log.info("%s was unlocked.", str(self))
 
     @property
     def num_running(self) -> int:
@@ -187,7 +199,7 @@ class BaseTaskPool:
         finally:
             await self._task_ending(task_id, custom_callback=end_callback)
 
-    async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
+    async def _start_task(self, awaitable: Awaitable, ignore_lock: bool = False, end_callback: EndCallbackT = None,
                           cancel_callback: CancelCallbackT = None) -> int:
         """
         Starts a coroutine as a new task in the pool.
 
         Args:
             awaitable: The actual coroutine to be run within the task pool.
-            ignore_closed (optional):
-                If `True`, even if the pool is closed, the task will still be started.
+            ignore_lock (optional):
+                If `True`, even if the pool is locked, the task will still be started.
             end_callback (optional):
                 A callback to execute after the task has ended.
                 It is run with the task's ID as its only positional argument.
             cancel_callback (optional):
                 A callback to execute after cancellation of the task.
 
         Raises:
             `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
- `asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`. + `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`. """ if not iscoroutine(awaitable): raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") - if not (self.is_open or ignore_closed): - raise exceptions.PoolIsClosed("Cannot start new tasks") + if self._locked and not ignore_lock: + raise exceptions.PoolIsLocked("Cannot start new tasks") await self._enough_room.acquire() task_id = self._counter self._counter += 1 @@ -303,16 +315,11 @@ class BaseTaskPool: self._interrupt_flag.clear() return results - def close(self) -> None: - """Disallows any more tasks to be started in the pool.""" - self._open = False - log.info("%s is closed!", str(self)) - async def gather(self, return_exceptions: bool = False): """ Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks. - The `close()` method must have been called prior to this. + The `lock()` method must have been called prior to this. Note that there may be an unknown number of coroutine functions "queued" to be run as tasks. This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller @@ -327,10 +334,10 @@ class BaseTaskPool: return_exceptions (optional): Passed directly into `gather`. Raises: - `asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet. + `asyncio_taskpool.exceptions.PoolStillUnlocked` if the pool has not been locked yet. """ - if self._open: - raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") + if not self._locked: + raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered") await gather(*self._before_gathering) results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), return_exceptions=return_exceptions) @@ -415,7 +422,7 @@ class TaskPool(BaseTaskPool): Raises: `NotCoroutine` if `func` is not a coroutine function. - `PoolIsClosed` if the pool has been closed already. + `PoolIsLocked` if the pool has been locked already. """ ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))) # TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions @@ -468,7 +475,7 @@ class TaskPool(BaseTaskPool): try: await self._start_task( star_function(func, arg, arg_stars=arg_stars), - ignore_closed=True, + ignore_lock=True, end_callback=partial(TaskPool._queue_callback, self, q=q, first_batch_started=first_batch_started, func=func, arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback), @@ -579,10 +586,10 @@ class TaskPool(BaseTaskPool): It is run with the task's ID as its only positional argument. Raises: - `asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed. + `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked. """ - if not self.is_open: - raise exceptions.PoolIsClosed("Cannot start new tasks") + if not self._locked: + raise exceptions.PoolIsLocked("Cannot start new tasks") args_queue = self._set_up_args_queue(args_iter, num_tasks) # We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the # `_queue_callback` triggered by one or more of them. @@ -628,7 +635,7 @@ class TaskPool(BaseTaskPool): It is run with the task's ID as its only positional argument. 
Raises: - `PoolIsClosed` if the pool has been closed. + `PoolIsLocked` if the pool has been locked. `NotCoroutine` if `func` is not a coroutine function. """ await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks, diff --git a/tests/test_pool.py b/tests/test_pool.py index e483957..bc41e2b 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -68,7 +68,7 @@ class BaseTaskPoolTestCase(CommonTestCase): def test_init(self): self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) - self.assertTrue(self.task_pool._open) + self.assertFalse(self.task_pool._locked) self.assertEqual(0, self.task_pool._counter) self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) @@ -103,9 +103,23 @@ class BaseTaskPoolTestCase(CommonTestCase): self.task_pool.pool_size = new_size = 69 self.assertEqual(new_size, self.task_pool._pool_size) - def test_is_open(self): - self.task_pool._open = FOO - self.assertEqual(FOO, self.task_pool.is_open) + def test_is_locked(self): + self.task_pool._locked = FOO + self.assertEqual(FOO, self.task_pool.is_locked) + + def test_lock(self): + assert not self.task_pool._locked + self.task_pool.lock() + self.assertTrue(self.task_pool._locked) + self.task_pool.lock() + self.assertTrue(self.task_pool._locked) + + def test_unlock(self): + self.task_pool._locked = True + self.task_pool.unlock() + self.assertFalse(self.task_pool._locked) + self.task_pool.unlock() + self.assertFalse(self.task_pool._locked) def test_num_running(self): self.task_pool._running = ['foo', 'bar', 'baz'] @@ -211,11 +225,9 @@ class BaseTaskPoolTestCase(CommonTestCase): @patch.object(pool, 'create_task') @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) - @patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock) - async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock, - mock__task_wrapper: AsyncMock, mock_create_task: MagicMock): + async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock, + mock_create_task: MagicMock): def reset_mocks() -> None: - mock_is_open.reset_mock() mock__task_name.reset_mock() mock__task_wrapper.reset_mock() mock_create_task.reset_mock() @@ -226,31 +238,27 @@ class BaseTaskPoolTestCase(CommonTestCase): self.task_pool._counter = count = 123 self.task_pool._enough_room._value = room = 123 + def check_nothing_changed() -> None: + self.assertEqual(count, self.task_pool._counter) + self.assertNotIn(count, self.task_pool._running) + self.assertEqual(room, self.task_pool._enough_room._value) + mock__task_name.assert_not_called() + mock__task_wrapper.assert_not_called() + mock_create_task.assert_not_called() + reset_mocks() + with self.assertRaises(exceptions.NotCoroutine): await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) - self.assertEqual(count, self.task_pool._counter) - self.assertNotIn(count, self.task_pool._running) - self.assertEqual(room, self.task_pool._enough_room._value) - mock_is_open.assert_not_called() - mock__task_name.assert_not_called() - mock__task_wrapper.assert_not_called() - mock_create_task.assert_not_called() - reset_mocks() + check_nothing_changed() - mock_is_open.return_value = ignore_closed = False + self.task_pool._locked = True + ignore_closed = False mock_awaitable = mock_coroutine() - with self.assertRaises(exceptions.PoolIsClosed): + with 
self.assertRaises(exceptions.PoolIsLocked): await self.task_pool._start_task(mock_awaitable, ignore_closed, end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) await mock_awaitable - self.assertEqual(count, self.task_pool._counter) - self.assertNotIn(count, self.task_pool._running) - self.assertEqual(room, self.task_pool._enough_room._value) - mock_is_open.assert_called_once_with() - mock__task_name.assert_not_called() - mock__task_wrapper.assert_not_called() - mock_create_task.assert_not_called() - reset_mocks() + check_nothing_changed() ignore_closed = True mock_awaitable = mock_coroutine() @@ -261,7 +269,6 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertEqual(count + 1, self.task_pool._counter) self.assertEqual(mock_task, self.task_pool._running[count]) self.assertEqual(room - 1, self.task_pool._enough_room._value) - mock_is_open.assert_called_once_with() mock__task_name.assert_called_once_with(count) mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) @@ -280,7 +287,6 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertEqual(count + 1, self.task_pool._counter) self.assertNotIn(count, self.task_pool._running) self.assertEqual(room, self.task_pool._enough_room._value) - mock_is_open.assert_called_once_with() mock__task_name.assert_called_once_with(count) mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) @@ -345,11 +351,6 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) - def test_close(self): - assert self.task_pool._open - self.task_pool.close() - self.assertFalse(self.task_pool._open) - async def test_gather(self): test_exception = TestException() mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) @@ -361,8 +362,8 @@ class BaseTaskPoolTestCase(CommonTestCase): self.task_pool._running = running = {789: mock_running_func()} self.task_pool._interrupt_flag.set() - assert self.task_pool._open - with self.assertRaises(exceptions.PoolStillOpen): + assert not self.task_pool._locked + with self.assertRaises(exceptions.PoolStillUnlocked): await self.task_pool.gather() self.assertDictEqual(self.task_pool._ended, ended) self.assertDictEqual(self.task_pool._cancelled, cancelled) @@ -370,7 +371,7 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertListEqual(self.task_pool._before_gathering, before_gather) self.assertTrue(self.task_pool._interrupt_flag.is_set()) - self.task_pool._open = False + self.task_pool._locked = True def check_assertions(output) -> None: self.assertListEqual([FOO, test_exception, BAR], output) @@ -449,7 +450,7 @@ class TaskPoolTestCase(CommonTestCase): mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock() self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb)) self.assertTrue(q.empty()) - mock__start_task.assert_awaited_once_with(awaitable, ignore_closed=True, + mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True, end_callback=queue_callback, cancel_callback=cancel_cb) mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars) mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool, @@ -459,7 +460,7 @@ class TaskPoolTestCase(CommonTestCase): 
mock_star_function.reset_mock() mock_partial.reset_mock() - self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb)) + self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb)) self.assertTrue(q.empty()) mock__start_task.assert_not_awaited() mock_star_function.assert_not_called() @@ -528,9 +529,8 @@ class TaskPoolTestCase(CommonTestCase): @patch.object(pool, 'Event') @patch.object(pool.TaskPool, '_queue_consumer') @patch.object(pool.TaskPool, '_set_up_args_queue') - @patch.object(pool.TaskPool, 'is_open', new_callable=PropertyMock) - async def test__map(self, mock_is_open: MagicMock, mock__set_up_args_queue: MagicMock, - mock__queue_consumer: AsyncMock, mock_event_cls: MagicMock): + async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock, + mock_event_cls: MagicMock): qsize = 4 mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize)) mock_flag_set = MagicMock() @@ -540,17 +540,14 @@ class TaskPoolTestCase(CommonTestCase): args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2 end_cb, cancel_cb = MagicMock(), MagicMock() - mock_is_open.return_value = False - with self.assertRaises(exceptions.PoolIsClosed): + self.task_pool._locked = False + with self.assertRaises(exceptions.PoolIsLocked): await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb) - mock_is_open.assert_called_once_with() mock__set_up_args_queue.assert_not_called() mock__queue_consumer.assert_not_awaited() mock_flag_set.assert_not_called() - mock_is_open.reset_mock() - - mock_is_open.return_value = True + self.task_pool._locked = True self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)) mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks) mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars, diff --git a/usage/USAGE.md b/usage/USAGE.md index 08cddfe..a496cdb 100644 --- a/usage/USAGE.md +++ b/usage/USAGE.md @@ -7,13 +7,13 @@ The minimum required setup is a "worker" coroutine function that can do somethin The following demo code enables full log output first for additional clarity. It is complete and should work as is. ### Code + ```python import logging import asyncio from asyncio_taskpool.pool import SimpleTaskPool - logging.getLogger().setLevel(logging.NOTSET) logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler()) @@ -38,7 +38,7 @@ async def main() -> None: await pool.start() # launches work task 3 await asyncio.sleep(1.5) # lets the tasks work for a bit pool.stop(2) # cancels tasks 3 and 2 - pool.close() # required for the last line + pool.lock() # required for the last line await pool.gather() # awaits all tasks, then flushes the pool @@ -60,7 +60,7 @@ did 1 did 1 did 1 did 0 -SimpleTaskPool-0 is closed! +SimpleTaskPool-0 is locked! Cancelling SimpleTaskPool-0_Task-3 ... Cancelled SimpleTaskPool-0_Task-3 Ended SimpleTaskPool-0_Task-3 @@ -86,13 +86,13 @@ As with the simple example, we need "worker" coroutine functions that can do som The following demo code enables full log output first for additional clarity. It is complete and should work as is. 
 ### Code
+
 ```python
 import logging
 import asyncio
 
 from asyncio_taskpool.pool import TaskPool
 
-
 logging.getLogger().setLevel(logging.NOTSET)
 logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
 
@@ -132,8 +132,8 @@ async def main() -> None:
     print("Calling `starmap`...")
     await pool.starmap(other_work, args_list, num_tasks=2)
     print("`starmap` returned")
-    # Now we close the pool, so that we can safely await all our tasks.
-    pool.close()
+    # Now we lock the pool, so that we can safely await all our tasks.
+    pool.lock()
     # Finally, we block, until all tasks have ended.
     print("Called `gather`")
     await pool.gather()
@@ -177,7 +177,7 @@ work with 150
 work with 150
 other_work with 8
 Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
-TaskPool-0 is closed!
+TaskPool-0 is locked!
 Started TaskPool-0_Task-3
 other_work with 9
 `starmap` returned
diff --git a/usage/example_server.py b/usage/example_server.py
index 0478f43..6ba867a 100644
--- a/usage/example_server.py
+++ b/usage/example_server.py
@@ -48,12 +48,12 @@ async def main() -> None:
     control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
     # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
     await q.join()
-    # Since we don't need any "work" done anymore, we can close our control server by cancelling the task.
+    # Since we don't need any "work" done anymore, we can shut down our control server by cancelling the task.
     control_server_task.cancel()
     # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
     # we can now safely cancel their tasks.
     pool.stop_all()
-    pool.close()
+    pool.lock()
     # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
     # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
     # we just silently collect their exceptions along with their return values.
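
For quick reference, the renamed workflow boils down to the following minimal sketch. It is illustrative only and not part of the diff above: it assumes the `SimpleTaskPool` constructor and the `start()`/`lock()`/`unlock()`/`gather()` calls behave as shown in the `usage/USAGE.md` hunks, and `work` is a stand-in worker coroutine.

```python
import asyncio

from asyncio_taskpool.pool import SimpleTaskPool


async def work(n: int) -> None:
    # Stand-in worker: counts up to `n`, yielding control at every step.
    for i in range(n):
        await asyncio.sleep(1)
        print("did", i)


async def main() -> None:
    pool = SimpleTaskPool(work, (3,))  # set up the pool; no tasks are running yet
    await pool.start()                 # launch a single worker task
    await asyncio.sleep(1.5)           # let it do some work
    pool.lock()                        # replaces the old `pool.close()`; no new tasks can be started
    print(pool.is_locked)              # replaces the old `is_open` property (with inverted meaning)
    # pool.unlock()                    # newly possible: a locked pool can be reopened for new tasks
    await pool.gather()                # raises `PoolStillUnlocked` if the pool was not locked first


asyncio.run(main())
```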