generated from daniil-berg/boilerplate-py
Fixed a potential race condition when gathering meta tasks
This commit is contained in:
@ -20,7 +20,6 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
||||
|
||||
|
||||
from asyncio.coroutines import iscoroutinefunction
|
||||
from asyncio.queues import Queue
|
||||
from importlib import import_module
|
||||
from inspect import getdoc
|
||||
from typing import Any, Optional, Union
|
||||
@ -86,11 +85,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
|
||||
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
|
||||
|
||||
|
||||
async def join_queue(q: Queue) -> None:
|
||||
"""Wrapper function around the join method of an `asyncio.Queue` instance."""
|
||||
await q.join()
|
||||
|
||||
|
||||
def get_first_doc_line(obj: object) -> str:
    """
    Takes an object and returns the first (non-empty) line of its docstring.

    Args:
        obj: Any object; its docstring is looked up via `inspect.getdoc`.

    Returns:
        The first line of the cleaned-up docstring with surrounding whitespace removed,
        or an empty string if the object has no docstring at all.
    """
    doc = getdoc(obj)
    if doc is None:
        # `getdoc` returns `None` for undocumented objects; calling `.strip()` on that would raise `AttributeError`.
        return ""
    # Leading blank lines are stripped first, so the split yields the first line that actually has content.
    return doc.strip().split("\n", 1)[0].strip()
|
||||
|
@ -42,7 +42,7 @@ from . import exceptions
|
||||
from .queue_context import Queue
|
||||
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
|
||||
from .internals.group_register import TaskGroupRegister
|
||||
from .internals.helpers import execute_optional, star_function, join_queue
|
||||
from .internals.helpers import execute_optional, star_function
|
||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||
|
||||
|
||||
@ -82,8 +82,7 @@ class BaseTaskPool:
|
||||
self._tasks_cancelled: Dict[int, Task] = {}
|
||||
self._tasks_ended: Dict[int, Task] = {}
|
||||
|
||||
# These next three attributes act as synchronisation primitives necessary for managing the pool.
|
||||
self._before_gathering: List[Awaitable] = []
|
||||
# These next two attributes act as synchronisation primitives necessary for managing the pool.
|
||||
self._enough_room: Semaphore = Semaphore()
|
||||
self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
|
||||
|
||||
@ -468,13 +467,11 @@ class BaseTaskPool:
|
||||
"""
|
||||
if not self._locked:
|
||||
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
|
||||
await gather(*self._before_gathering)
|
||||
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._tasks_ended.clear()
|
||||
self._tasks_cancelled.clear()
|
||||
self._tasks_running.clear()
|
||||
self._before_gathering.clear()
|
||||
self._closed = True
|
||||
|
||||
|
||||
@ -499,7 +496,7 @@ class TaskPool(BaseTaskPool):
|
||||
def __init__(self, pool_size: int = inf, name: str = None) -> None:
|
||||
super().__init__(pool_size=pool_size, name=name)
|
||||
# In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
|
||||
# meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks.
|
||||
# meta tasks that are/were running in the context of that group, and a bucket for cancelled meta tasks.
|
||||
self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
|
||||
self._meta_tasks_cancelled: Set[Task] = set()
|
||||
|
||||
@ -592,11 +589,11 @@ class TaskPool(BaseTaskPool):
|
||||
Args:
|
||||
return_exceptions (optional): Passed directly into `gather`.
|
||||
"""
|
||||
await super().flush(return_exceptions=return_exceptions)
|
||||
with suppress(CancelledError):
|
||||
await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
|
||||
return_exceptions=return_exceptions)
|
||||
self._meta_tasks_cancelled.clear()
|
||||
await super().flush(return_exceptions=return_exceptions)
|
||||
|
||||
async def gather_and_close(self, return_exceptions: bool = False):
|
||||
"""
|
||||
@ -607,9 +604,8 @@ class TaskPool(BaseTaskPool):
|
||||
The `lock()` method must have been called prior to this.
|
||||
|
||||
Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
|
||||
tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed,
|
||||
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
|
||||
this, make sure to call :meth:`cancel_all` prior to this.
|
||||
tasks launched by methods such as :meth:`map`, which end by themselves, only once the arguments iterator is
|
||||
fully consumed (which may not even be possible). To avoid this, make sure to call :meth:`cancel_all` first.
|
||||
|
||||
This method may also block, if one of the tasks blocks while catching an `asyncio.CancelledError` or if any of
|
||||
the callbacks registered for a task blocks for whatever reason.
|
||||
@ -620,15 +616,12 @@ class TaskPool(BaseTaskPool):
|
||||
Raises:
|
||||
`PoolStillUnlocked`: The pool has not been locked yet.
|
||||
"""
|
||||
# TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`).
|
||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
||||
not_cancelled_meta_tasks = set()
|
||||
while self._group_meta_tasks_running:
|
||||
_, meta_tasks = self._group_meta_tasks_running.popitem()
|
||||
not_cancelled_meta_tasks.update(meta_tasks)
|
||||
not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
|
||||
with suppress(CancelledError):
|
||||
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
|
||||
self._meta_tasks_cancelled.clear()
|
||||
self._group_meta_tasks_running.clear()
|
||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
|
||||
@ -789,7 +782,7 @@ class TaskPool(BaseTaskPool):
|
||||
# The following line blocks **only if** the number of running tasks spawned by this method has reached the
|
||||
# specified maximum as determined in :meth:`_map`.
|
||||
await map_semaphore.acquire()
|
||||
# We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
|
||||
# We await the queue's `get()` coroutine and ensure that its `item_processed()` method is called.
|
||||
async with arg_queue as next_arg:
|
||||
if next_arg is self._QUEUE_END_SENTINEL:
|
||||
# The :meth:`_queue_producer` either reached the last argument or was cancelled.
|
||||
@ -797,6 +790,9 @@ class TaskPool(BaseTaskPool):
|
||||
try:
|
||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
||||
except CancelledError:
|
||||
map_semaphore.release()
|
||||
return
|
||||
except Exception as e:
|
||||
# This means an exception occurred during task **creation**, meaning no task has been created.
|
||||
# It does not imply an error within the task itself.
|
||||
@ -856,13 +852,10 @@ class TaskPool(BaseTaskPool):
|
||||
async with group_reg:
|
||||
# Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
|
||||
# `_queue_producer()`; that way an argument
|
||||
# TODO: Perhaps this can be simplified to just one meta-task with no need for a queue.
|
||||
# The limiting factor honoring the group size is already the semaphore in the queue consumer;
|
||||
# Try to write this without a producer, instead consuming the `arg_iter` directly.
|
||||
arg_queue = Queue(maxsize=group_size)
|
||||
# TODO: This is the wrong thing to await before gathering!
|
||||
# Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
|
||||
# "finishes" the entire queue before the producer manages to put more items in it, thus returning
|
||||
# the `join` call before the arguments iterator was fully consumed.
|
||||
# Probably the queue producer task should be awaited before gathering instead.
|
||||
self._before_gathering.append(join_queue(arg_queue))
|
||||
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||
# Start the producer and consumer meta tasks.
|
||||
meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
|
||||
|
Reference in New Issue
Block a user