renamed "closing" a pool to "locking" it

This commit is contained in:
2022-02-08 23:09:33 +01:00
parent 012c8ac639
commit bc9d2f243e
6 changed files with 88 additions and 84 deletions

View File

@ -2,7 +2,7 @@ class PoolException(Exception):
pass
class PoolIsClosed(PoolException):
class PoolIsLocked(PoolException):
pass
@ -22,7 +22,7 @@ class InvalidTaskID(PoolException):
pass
class PoolStillOpen(PoolException):
class PoolStillUnlocked(PoolException):
pass

View File

@ -31,7 +31,7 @@ class BaseTaskPool:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self._enough_room: Semaphore = Semaphore()
self.pool_size = pool_size
self._open: bool = True
self._locked: bool = False
self._counter: int = 0
self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {}
@ -71,9 +71,21 @@ class BaseTaskPool:
self._pool_size = value
@property
def is_open(self) -> bool:
"""Returns `True` if more the pool has not been closed yet."""
return self._open
def is_locked(self) -> bool:
"""Returns `True` if more the pool has been locked (see below)."""
return self._locked
def lock(self) -> None:
"""Disallows any more tasks to be started in the pool."""
if not self._locked:
self._locked = True
log.info("%s is locked!", str(self))
def unlock(self) -> None:
"""Allows new tasks to be started in the pool."""
if self._locked:
self._locked = False
log.info("%s was unlocked.", str(self))
@property
def num_running(self) -> int:
@ -187,7 +199,7 @@ class BaseTaskPool:
finally:
await self._task_ending(task_id, custom_callback=end_callback)
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
async def _start_task(self, awaitable: Awaitable, ignore_lock: bool = False, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int:
"""
Starts a coroutine as a new task in the pool.
@ -197,8 +209,8 @@ class BaseTaskPool:
Args:
awaitable:
The actual coroutine to be run within the task pool.
ignore_closed (optional):
If `True`, even if the pool is closed, the task will still be started.
ignore_lock (optional):
If `True`, even if the pool is locked, the task will still be started.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the task's ID as its only positional argument.
@ -208,12 +220,12 @@ class BaseTaskPool:
Raises:
`asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
`asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`.
"""
if not iscoroutine(awaitable):
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
if self._locked and not ignore_lock:
raise exceptions.PoolIsLocked("Cannot start new tasks")
await self._enough_room.acquire()
task_id = self._counter
self._counter += 1
@ -303,16 +315,11 @@ class BaseTaskPool:
self._interrupt_flag.clear()
return results
def close(self) -> None:
"""Disallows any more tasks to be started in the pool."""
self._open = False
log.info("%s is closed!", str(self))
async def gather(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
The `close()` method must have been called prior to this.
The `lock()` method must have been called prior to this.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
@ -327,10 +334,10 @@ class BaseTaskPool:
return_exceptions (optional): Passed directly into `gather`.
Raises:
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
`asyncio_taskpool.exceptions.PoolStillUnlocked` if the pool has not been locked yet.
"""
if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
if not self._locked:
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
await gather(*self._before_gathering)
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions)
@ -415,7 +422,7 @@ class TaskPool(BaseTaskPool):
Raises:
`NotCoroutine` if `func` is not a coroutine function.
`PoolIsClosed` if the pool has been closed already.
`PoolIsLocked` if the pool has been locked already.
"""
ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
@ -468,7 +475,7 @@ class TaskPool(BaseTaskPool):
try:
await self._start_task(
star_function(func, arg, arg_stars=arg_stars),
ignore_closed=True,
ignore_lock=True,
end_callback=partial(TaskPool._queue_callback, self, q=q, first_batch_started=first_batch_started,
func=func, arg_stars=arg_stars, end_callback=end_callback,
cancel_callback=cancel_callback),
@ -579,10 +586,10 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument.
Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed.
`asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked.
"""
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
if not self._locked:
raise exceptions.PoolIsLocked("Cannot start new tasks")
args_queue = self._set_up_args_queue(args_iter, num_tasks)
# We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
# `_queue_callback` triggered by one or more of them.
@ -628,7 +635,7 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument.
Raises:
`PoolIsClosed` if the pool has been closed.
`PoolIsLocked` if the pool has been locked.
`NotCoroutine` if `func` is not a coroutine function.
"""
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,