From 3c69740c8d3bfd4d4b603b18e6c94e17231f6050 Mon Sep 17 00:00:00 2001
From: Daniil Fajnberg <mail@daniil.fajnberg.de>
Date: Tue, 8 Feb 2022 12:09:21 +0100
Subject: [PATCH] factored out the queue setup of `TaskPool._map`

---
 setup.cfg                       |  2 +-
 src/asyncio_taskpool/helpers.py |  5 ++++
 src/asyncio_taskpool/pool.py    | 42 ++++++++++++++++++---------------
 3 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 873a05b..b9b0a3e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = asyncio-taskpool
-version = 0.1.2
+version = 0.1.3
 author = Daniil Fajnberg
 author_email = mail@daniil.fajnberg.de
 description = Dynamically manage pools of asyncio tasks
diff --git a/src/asyncio_taskpool/helpers.py b/src/asyncio_taskpool/helpers.py
index d009a1f..35ccaa4 100644
--- a/src/asyncio_taskpool/helpers.py
+++ b/src/asyncio_taskpool/helpers.py
@@ -1,4 +1,5 @@
 from asyncio.coroutines import iscoroutinefunction
+from asyncio.queues import Queue
 from typing import Any, Optional
 
 from .types import T, AnyCallableT, ArgsT, KwArgsT
@@ -22,3 +23,7 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
     if arg_stars == 2:
         return function(**arg)
     raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
+
+
+async def join_queue(q: Queue) -> None:
+    await q.join()
diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py
index e0d4f56..2c08b39 100644
--- a/src/asyncio_taskpool/pool.py
+++ b/src/asyncio_taskpool/pool.py
@@ -10,7 +10,7 @@ from math import inf
 from typing import Any, Awaitable, Dict, Iterable, Iterator, List
 
 from . import exceptions
-from .helpers import execute_optional, star_function
+from .helpers import execute_optional, star_function, join_queue
 from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
 
 
@@ -498,6 +498,26 @@ class TaskPool(BaseTaskPool):
             await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
         await execute_optional(end_callback, args=(task_id,))
 
+    def _fill_args_queue(self, q: Queue, args_iter: ArgsT, num_tasks: int) -> int:
+        args_iter = iter(args_iter)
+        try:
+            # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch
+            # of tasks, which will be at most `num_tasks` (meaning the queue will be full).
+            for i in range(num_tasks):
+                q.put_nowait(next(args_iter))
+        except StopIteration:
+            # If we get here, the number of elements in the arguments iterator was less than the specified
+            # `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue.
+            # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
+            num_tasks = q.qsize()
+        else:
+            # There may be more elements in the arguments iterator, so we need the `_queue_producer`.
+            # It will have exclusive access to `args_iter` from now on.
+            # If the queue is already full, it will wait until one of the tasks in the first batch ends
+            # before putting the next item in it.
+            create_task(self._queue_producer(q, args_iter))
+        return num_tasks
+
     async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
                    end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
         """
@@ -533,24 +553,8 @@
         if not self.is_open:
             raise exceptions.PoolIsClosed("Cannot start new tasks")
         args_queue = Queue(maxsize=num_tasks)
-        self._before_gathering.append(args_queue.join())
-        args_iter = iter(args_iter)
-        try:
-            # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
-            # tasks, which will be at most `num_tasks` (meaning the queue will be full).
-            for i in range(num_tasks):
-                args_queue.put_nowait(next(args_iter))
-        except StopIteration:
-            # If we get here, this means that the number of elements in the arguments iterator was less than the
-            # specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue.
-            # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
-            num_tasks = args_queue.qsize()
-        else:
-            # There may be more elements in the arguments iterator, so we need the `_queue_producer`.
-            # It will have exclusive access to the `args_iter` from now on.
-            # If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting
-            # the next item in it.
-            create_task(self._queue_producer(args_queue, args_iter))
+        self._before_gathering.append(join_queue(args_queue))
+        num_tasks = self._fill_args_queue(args_queue, args_iter, num_tasks)
         for _ in range(num_tasks):
            # This is where blocking can occur, if the pool is full.
            await self._queue_consumer(args_queue, func,
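Note: the queue setup factored out above follows a pre-fill-then-produce pattern that can be reproduced in isolation. Below is a minimal, self-contained sketch of that pattern in plain asyncio; the names `fill_args_queue`, `produce`, and `main` are illustrative stand-ins for this note only, not part of the pool's API.

import asyncio
from typing import Any, Iterable, Iterator


async def produce(q: asyncio.Queue, args_iter: Iterator[Any]) -> None:
    # Stand-in for `_queue_producer`: feeds the remaining arguments into the
    # bounded queue, blocking on `put` whenever the queue is full.
    for arg in args_iter:
        await q.put(arg)


def fill_args_queue(q: asyncio.Queue, args_iter: Iterable[Any], num_tasks: int) -> int:
    # Pre-load up to `num_tasks` arguments synchronously and return the number
    # of tasks that can be started immediately.
    args_iter = iter(args_iter)
    try:
        for _ in range(num_tasks):
            q.put_nowait(next(args_iter))
    except StopIteration:
        # Iterator exhausted early: start one task per queued argument and
        # skip the producer entirely.
        num_tasks = q.qsize()
    else:
        # More arguments may remain: hand the iterator off to a producer task.
        asyncio.create_task(produce(q, args_iter))
    return num_tasks


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue(maxsize=3)
    num_tasks = fill_args_queue(q, range(7), num_tasks=3)
    assert num_tasks == 3 and q.full()
    # Drain all seven arguments; the producer refills the queue as slots free up.
    for _ in range(7):
        print("consumed", await q.get())
        q.task_done()
    await q.join()  # returns once every argument has been marked done


asyncio.run(main())

Since `put_nowait` never yields to the event loop, the first batch of arguments is guaranteed to be enqueued before the producer (or any consumer) runs; after that, the queue's `maxsize` provides the backpressure that keeps the producer at most `num_tasks` items ahead of the consumers.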