renamed "closing" a pool to "locking" it

This commit is contained in:
Daniil Fajnberg 2022-02-08 23:09:33 +01:00
parent 012c8ac639
commit bc9d2f243e
6 changed files with 88 additions and 84 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.1.7
version = 0.2.0
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -2,7 +2,7 @@ class PoolException(Exception):
pass
class PoolIsClosed(PoolException):
class PoolIsLocked(PoolException):
pass
@ -22,7 +22,7 @@ class InvalidTaskID(PoolException):
pass
class PoolStillOpen(PoolException):
class PoolStillUnlocked(PoolException):
pass

View File

@ -31,7 +31,7 @@ class BaseTaskPool:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self._enough_room: Semaphore = Semaphore()
self.pool_size = pool_size
self._open: bool = True
self._locked: bool = False
self._counter: int = 0
self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {}
@ -71,9 +71,21 @@ class BaseTaskPool:
self._pool_size = value
@property
def is_open(self) -> bool:
"""Returns `True` if more the pool has not been closed yet."""
return self._open
def is_locked(self) -> bool:
"""Returns `True` if more the pool has been locked (see below)."""
return self._locked
def lock(self) -> None:
"""Disallows any more tasks to be started in the pool."""
if not self._locked:
self._locked = True
log.info("%s is locked!", str(self))
def unlock(self) -> None:
"""Allows new tasks to be started in the pool."""
if self._locked:
self._locked = False
log.info("%s was unlocked.", str(self))
@property
def num_running(self) -> int:
@ -187,7 +199,7 @@ class BaseTaskPool:
finally:
await self._task_ending(task_id, custom_callback=end_callback)
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
async def _start_task(self, awaitable: Awaitable, ignore_lock: bool = False, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int:
"""
Starts a coroutine as a new task in the pool.
@ -197,8 +209,8 @@ class BaseTaskPool:
Args:
awaitable:
The actual coroutine to be run within the task pool.
ignore_closed (optional):
If `True`, even if the pool is closed, the task will still be started.
ignore_lock (optional):
If `True`, even if the pool is locked, the task will still be started.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the task's ID as its only positional argument.
@ -208,12 +220,12 @@ class BaseTaskPool:
Raises:
`asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
`asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`.
"""
if not iscoroutine(awaitable):
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
if self._locked and not ignore_lock:
raise exceptions.PoolIsLocked("Cannot start new tasks")
await self._enough_room.acquire()
task_id = self._counter
self._counter += 1
@ -303,16 +315,11 @@ class BaseTaskPool:
self._interrupt_flag.clear()
return results
def close(self) -> None:
"""Disallows any more tasks to be started in the pool."""
self._open = False
log.info("%s is closed!", str(self))
async def gather(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
The `close()` method must have been called prior to this.
The `lock()` method must have been called prior to this.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
@ -327,10 +334,10 @@ class BaseTaskPool:
return_exceptions (optional): Passed directly into `gather`.
Raises:
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
`asyncio_taskpool.exceptions.PoolStillUnlocked` if the pool has not been locked yet.
"""
if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
if not self._locked:
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
await gather(*self._before_gathering)
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions)
@ -415,7 +422,7 @@ class TaskPool(BaseTaskPool):
Raises:
`NotCoroutine` if `func` is not a coroutine function.
`PoolIsClosed` if the pool has been closed already.
`PoolIsLocked` if the pool has been locked already.
"""
ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
@ -468,7 +475,7 @@ class TaskPool(BaseTaskPool):
try:
await self._start_task(
star_function(func, arg, arg_stars=arg_stars),
ignore_closed=True,
ignore_lock=True,
end_callback=partial(TaskPool._queue_callback, self, q=q, first_batch_started=first_batch_started,
func=func, arg_stars=arg_stars, end_callback=end_callback,
cancel_callback=cancel_callback),
@ -579,10 +586,10 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument.
Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed.
`asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked.
"""
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
if not self._locked:
raise exceptions.PoolIsLocked("Cannot start new tasks")
args_queue = self._set_up_args_queue(args_iter, num_tasks)
# We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
# `_queue_callback` triggered by one or more of them.
@ -628,7 +635,7 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument.
Raises:
`PoolIsClosed` if the pool has been closed.
`PoolIsLocked` if the pool has been locked.
`NotCoroutine` if `func` is not a coroutine function.
"""
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,

View File

@ -68,7 +68,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
self.assertTrue(self.task_pool._open)
self.assertFalse(self.task_pool._locked)
self.assertEqual(0, self.task_pool._counter)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
@ -103,9 +103,23 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool.pool_size = new_size = 69
self.assertEqual(new_size, self.task_pool._pool_size)
def test_is_open(self):
self.task_pool._open = FOO
self.assertEqual(FOO, self.task_pool.is_open)
def test_is_locked(self):
self.task_pool._locked = FOO
self.assertEqual(FOO, self.task_pool.is_locked)
def test_lock(self):
assert not self.task_pool._locked
self.task_pool.lock()
self.assertTrue(self.task_pool._locked)
self.task_pool.lock()
self.assertTrue(self.task_pool._locked)
def test_unlock(self):
self.task_pool._locked = True
self.task_pool.unlock()
self.assertFalse(self.task_pool._locked)
self.task_pool.unlock()
self.assertFalse(self.task_pool._locked)
def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz']
@ -211,11 +225,9 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock,
mock_create_task: MagicMock):
def reset_mocks() -> None:
mock_is_open.reset_mock()
mock__task_name.reset_mock()
mock__task_wrapper.reset_mock()
mock_create_task.reset_mock()
@ -226,31 +238,27 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123
def check_nothing_changed() -> None:
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_not_called()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
check_nothing_changed()
mock_is_open.return_value = ignore_closed = False
self.task_pool._locked = True
ignore_closed = False
mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsClosed):
with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
check_nothing_changed()
ignore_closed = True
mock_awaitable = mock_coroutine()
@ -261,7 +269,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
self.assertEqual(room - 1, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@ -280,7 +287,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@ -345,11 +351,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
def test_close(self):
assert self.task_pool._open
self.task_pool.close()
self.assertFalse(self.task_pool._open)
async def test_gather(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
@ -361,8 +362,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set()
assert self.task_pool._open
with self.assertRaises(exceptions.PoolStillOpen):
assert not self.task_pool._locked
with self.assertRaises(exceptions.PoolStillUnlocked):
await self.task_pool.gather()
self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled)
@ -370,7 +371,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertListEqual(self.task_pool._before_gathering, before_gather)
self.assertTrue(self.task_pool._interrupt_flag.is_set())
self.task_pool._open = False
self.task_pool._locked = True
def check_assertions(output) -> None:
self.assertListEqual([FOO, test_exception, BAR], output)
@ -449,7 +450,7 @@ class TaskPoolTestCase(CommonTestCase):
mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_closed=True,
mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
@ -459,7 +460,7 @@ class TaskPoolTestCase(CommonTestCase):
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
@ -528,9 +529,8 @@ class TaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'Event')
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
@patch.object(pool.TaskPool, 'is_open', new_callable=PropertyMock)
async def test__map(self, mock_is_open: MagicMock, mock__set_up_args_queue: MagicMock,
mock__queue_consumer: AsyncMock, mock_event_cls: MagicMock):
async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock,
mock_event_cls: MagicMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_flag_set = MagicMock()
@ -540,17 +540,14 @@ class TaskPoolTestCase(CommonTestCase):
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
mock_is_open.return_value = False
with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._locked = False
with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
mock_is_open.assert_called_once_with()
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_flag_set.assert_not_called()
mock_is_open.reset_mock()
mock_is_open.return_value = True
self.task_pool._locked = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,

View File

@ -7,13 +7,13 @@ The minimum required setup is a "worker" coroutine function that can do somethin
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
```python
import logging
import asyncio
from asyncio_taskpool.pool import SimpleTaskPool
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
@ -38,7 +38,7 @@ async def main() -> None:
await pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2
pool.close() # required for the last line
pool.lock() # required for the last line
await pool.gather() # awaits all tasks, then flushes the pool
@ -60,7 +60,7 @@ did 1
did 1
did 1
did 0
SimpleTaskPool-0 is closed!
SimpleTaskPool-0 is locked!
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
@ -86,13 +86,13 @@ As with the simple example, we need "worker" coroutine functions that can do som
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
```python
import logging
import asyncio
from asyncio_taskpool.pool import TaskPool
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
@ -132,8 +132,8 @@ async def main() -> None:
print("Calling `starmap`...")
await pool.starmap(other_work, args_list, num_tasks=2)
print("`starmap` returned")
# Now we close the pool, so that we can safely await all our tasks.
pool.close()
# Now we lock the pool, so that we can safely await all our tasks.
pool.lock()
# Finally, we block, until all tasks have ended.
print("Called `gather`")
await pool.gather()
@ -177,7 +177,7 @@ work with 150
work with 150
other_work with 8
Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
TaskPool-0 is closed!
TaskPool-0 is locked!
Started TaskPool-0_Task-3
other_work with 9
`starmap` returned

View File

@ -48,12 +48,12 @@ async def main() -> None:
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join()
# Since we don't need any "work" done anymore, we can close our control server by cancelling the task.
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
control_server_task.cancel()
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks.
pool.stop_all()
pool.close()
pool.lock()
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
# we just silently collect their exceptions along with their return values.