reworked the _map method to make use of an arguments queue; added docstrings to the TaskPool methods

This commit is contained in:
Daniil Fajnberg 2022-02-07 23:17:39 +01:00
parent 99ece436de
commit 16eda31648
2 changed files with 131 additions and 67 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.0.3
version = 0.1.1
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -3,6 +3,7 @@ from asyncio import gather
from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore
from asyncio.queues import Queue, QueueEmpty
from asyncio.tasks import Task, create_task
from functools import partial
from math import inf
@ -39,8 +40,7 @@ class BaseTaskPool:
self._num_ended: int = 0
self._idx: int = self._add_pool(self)
self._name: str = name
self._all_tasks_known_flag: Event = Event()
self._all_tasks_known_flag.set()
self._before_gathering: List[Awaitable] = []
self._interrupt_flag: Event = Event()
log.debug("%s initialized", str(self))
@ -330,7 +330,7 @@ class BaseTaskPool:
"""
if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
await self._all_tasks_known_flag.wait()
await gather(*self._before_gathering)
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {}
@ -418,22 +418,72 @@ class TaskPool(BaseTaskPool):
assert isinstance(ids, list)
return ids
async def _next_callback(self, task_id: int, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
async def _queue_producer(self, q: Queue, args_iter: Iterator[Any]) -> None:
"""
Keeps the arguments queue from `_map()` full as long as the iterator has elements.
If the `_interrupt_flag` gets set, the loop ends prematurely.
Args:
q:
The queue of function arguments to consume for starting the next task.
args_iter:
The iterator of function arguments to put into the queue.
"""
for arg in args_iter:
if self._interrupt_flag.is_set():
break
await q.put(arg) # This blocks as long as the queue is full.
async def _queue_consumer(self, q: Queue, func: CoroutineFunc, arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Wrapper around the `_start_task()` taking the next element from the arguments queue set up in `_map()`.
Partially constructs the `_queue_callback` function with the same arguments.
Args:
q:
The queue of function arguments to consume for starting the next task.
func:
The coroutine function to use for spawning the tasks within the task pool.
arg_stars (optional):
Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
end_callback (optional):
The actual callback specified to execute after the task (and the next one) has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the task's ID as its only positional argument.
"""
try:
arg = q.get_nowait()
except QueueEmpty:
return
try:
await self._start_task(
star_function(func, arg, arg_stars=arg_stars),
ignore_closed=True,
end_callback=partial(TaskPool._queue_callback, self, q=q, func=func, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback),
cancel_callback=cancel_callback
)
finally:
q.task_done()
async def _queue_callback(self, task_id: int, q: Queue, func: CoroutineFunc, arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Wrapper around an end callback function passed into the `_map()` method.
To be used in conjunction with `_start_next_task()` to simulate a queue of coroutines to be started as tasks
in the pool, whenever the `_map()` method is called.
Triggers the next `_queue_consumer` with the same arguments.
Args:
task_id:
The ID of the ending task.
q:
The queue of function arguments to consume for starting the next task.
func:
The coroutine function to use for spawning the tasks within the task pool.
args_iter:
The iterator of arguments; each element is to be passed into a `func` call when spawning a new task.
arg_stars (optional):
Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
end_callback (optional):
The actual callback specified to execute after the task (and the next one) has ended.
It is run with the `task_id` as its only positional argument.
@ -441,62 +491,19 @@ class TaskPool(BaseTaskPool):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the `task_id` as its only positional argument.
"""
reached_end = await self._start_next_task(func, args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback)
if reached_end:
self._all_tasks_known_flag.set()
await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
await execute_optional(end_callback, args=(task_id,))
async def _start_next_task(self, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> bool:
"""
Starts a new task in the pool using the next element from the arguments iterator.
Helper used in conjunction with the `_next_callback()` wrapper in the `_map()` method.
Args:
func:
The coroutine function to use for spawning the tasks within the task pool.
args_iter:
The iterator of arguments; each element is to be passed into a `func` call when spawning a new task.
arg_stars (optional):
Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
end_callback (optional):
The callback specified to execute after a task (and the next one) has ended.
It is run with the Task's ID as its only positional argument.
cancel_callback (optional):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the Task's ID as its only positional argument.
Returns:
`True` if the end of `args_iter` has been reached or the `_interrupt_flag` has been set; `False` otherwise.
"""
if self._interrupt_flag.is_set():
return True
try:
await self._start_task(
star_function(func, next(args_iter), arg_stars=arg_stars),
ignore_closed=True,
end_callback=partial(TaskPool._next_callback, self, func=func, args_iter=args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback),
cancel_callback=cancel_callback
)
except StopIteration:
return True
return False
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in chunks.
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable.
This method blocks, **only if** there is not enough room in the pool for the first chunk of new tasks.
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
It clears the internal `_all_tasks_known_flag` until the end of the iterable of arguments has been reached,
then sets it.
TODO: This is wrong because it may interfere with another call to this method.
Consider rebuilding this entire method using a `asyncio.Queue` instead of convoluted callbacks.
Then instead of the `_all_tasks_known_flag` the pool's `gather` can wait on the Queue's `join`...
It sets up an internal queue which is filled while consuming the arguments iterable.
The queue's `join()` method is added to the pool's `_before_gathering` list.
Args:
func:
@ -519,27 +526,84 @@ class TaskPool(BaseTaskPool):
"""
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
if self._all_tasks_known_flag.is_set():
self._all_tasks_known_flag.clear()
args_queue = Queue(maxsize=num_tasks)
self._before_gathering.append(args_queue.join())
args_iter = iter(args_iter)
try:
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
# tasks, which will be at most `num_tasks` (meaning the queue will be full).
for i in range(num_tasks):
args_queue.put_nowait(next(args_iter))
except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
num_tasks = args_queue.qsize()
else:
# There may be more elements in the arguments iterator, so we need the `_queue_producer`.
# It will have exclusive access to the `args_iter` from now on.
# If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting
# the next item in it.
create_task(self._queue_producer(args_queue, args_iter))
for _ in range(num_tasks):
reached_end = await self._start_next_task(func, args_iter, arg_stars, end_callback, cancel_callback)
if reached_end:
self._all_tasks_known_flag.set()
break
# This is where blocking can occur, if the pool is full.
await self._queue_consumer(args_queue, func,
arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
"""
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
Each coroutine looks like `func(arg)`, `arg` being an element from the iterable.
Once the first batch of tasks has started to run, this method returns.
As soon as on of them finishes, it triggers the start of a new task (assuming there is room in the pool)
consuming the next element from the arguments iterable.
If the size of the pool never imposes a limit, this ensures that there is almost continuously the desired number
of tasks from this call concurrently running within the pool.
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
Args:
func:
The coroutine function to use for spawning the new tasks within the task pool.
arg_iter:
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
num_tasks (optional):
The maximum number of the new tasks to run concurrently.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of a task.
It is run with the task's ID as its only positional argument.
Raises:
`PoolIsClosed` if the pool has been closed.
`NotCoroutine` if `func` is not a coroutine function.
"""
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as
positional arguments to the function.
Each coroutine then looks like `func(*arg)`, `arg` being an element from `args_iter`.
"""
await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Like `map()` except that the elements of `kwargs_iter` are expected to be iterables themselves to be unpacked as
keyword-arguments to the function.
Each coroutine then looks like `func(**arg)`, `arg` being an element from `kwargs_iter`.
"""
await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)