Compare commits


7 Commits

6 changed files with 577 additions and 98 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
-version = 0.0.2
+version = 0.1.5
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -24,3 +24,7 @@ class InvalidTaskID(PoolException):
class PoolStillOpen(PoolException):
pass
+class NotCoroutine(PoolException):
+pass
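
The new NotCoroutine exception is raised by the pool classes further down in this diff when they are handed something that is not a coroutine (function). A minimal sketch of how that surfaces from user code, assuming the package layout implied by the imports elsewhere in this changeset (asyncio_taskpool.pool / asyncio_taskpool.exceptions):

import asyncio
from asyncio_taskpool import exceptions
from asyncio_taskpool.pool import SimpleTaskPool

def regular_function(x):
    # Not a coroutine function; SimpleTaskPool.__init__ now rejects this.
    return x

async def main():
    try:
        SimpleTaskPool(regular_function, args=(1,))
    except exceptions.NotCoroutine as err:
        print("rejected:", err)

asyncio.run(main())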

View File

@ -0,0 +1,29 @@
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from typing import Any, Optional
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
if not callable(function):
return
if kwargs is None:
kwargs = {}
if iscoroutinefunction(function):
return await function(*args, **kwargs)
return function(*args, **kwargs)
def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
if arg_stars == 0:
return function(arg)
if arg_stars == 1:
return function(*arg)
if arg_stars == 2:
return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
async def join_queue(q: Queue) -> None:
await q.join()
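
These helpers back the pool refactoring below: execute_optional replaces the old module-private _execute_function, and star_function centralizes the 0/1/2-star argument unpacking previously done inline. A short usage sketch, assuming the new module is importable as asyncio_taskpool.helpers (the `from .helpers import ...` line in the pool module suggests as much):

import asyncio
from asyncio_taskpool.helpers import execute_optional, star_function

def add(a, b):
    return a + b

async def greet(name):
    return f"hello {name}"

async def demo():
    # star_function unpacks its single `arg` according to `arg_stars`:
    print(star_function(add, (1, 2), arg_stars=1))             # add(*arg) -> 3
    print(star_function(add, {'a': 1, 'b': 2}, arg_stars=2))   # add(**arg) -> 3
    # execute_optional awaits coroutine functions, calls plain callables,
    # and is a no-op for non-callables such as None:
    print(await execute_optional(greet, args=("world",)))      # hello world
    print(await execute_optional(None))                        # None

asyncio.run(demo())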

View File

@ -1,13 +1,16 @@
import logging
from asyncio import gather
-from asyncio.coroutines import iscoroutinefunction
+from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore
+from asyncio.queues import Queue, QueueEmpty
from asyncio.tasks import Task, create_task
+from functools import partial
from math import inf
-from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Awaitable, Dict, Iterable, Iterator, List
from . import exceptions
+from .helpers import execute_optional, star_function, join_queue
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
@ -37,8 +40,7 @@ class BaseTaskPool:
self._num_ended: int = 0
self._idx: int = self._add_pool(self)
self._name: str = name
-self._all_tasks_known_flag: Event = Event()
-self._all_tasks_known_flag.set()
+self._before_gathering: List[Awaitable] = []
self._interrupt_flag: Event = Event()
log.debug("%s initialized", str(self))
@ -135,7 +137,7 @@ class BaseTaskPool:
self._cancelled[task_id] = self._running.pop(task_id)
self._num_cancelled += 1
log.debug("Cancelled %s", self._task_name(task_id))
-await _execute_function(custom_callback, args=(task_id, ))
+await execute_optional(custom_callback, args=(task_id,))
async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
"""
@ -157,7 +159,7 @@ class BaseTaskPool:
self._num_ended += 1
self._enough_room.release()
log.info("Ended %s", self._task_name(task_id))
-await _execute_function(custom_callback, args=(task_id, ))
+await execute_optional(custom_callback, args=(task_id,))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any:
@ -199,14 +201,17 @@ class BaseTaskPool:
If `True`, even if the pool is closed, the task will still be started.
end_callback (optional):
A callback to execute after the task has ended.
-It is run with the `task_id` as its only positional argument.
+It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
-It is run with the `task_id` as its only positional argument.
+It is run with the task's ID as its only positional argument.
Raises:
+`asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
"""
+if not iscoroutine(awaitable):
+raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
await self._enough_room.acquire()
@ -292,7 +297,8 @@ class BaseTaskPool:
return_exceptions (optional): Passed directly into `gather`.
"""
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
-self._ended = self._cancelled = {}
+self._ended.clear()
+self._cancelled.clear()
if self._interrupt_flag.is_set():
self._interrupt_flag.clear()
return results
@ -325,76 +331,308 @@ class BaseTaskPool:
""" """
if self._open: if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
await self._all_tasks_known_flag.wait() await gather(*self._before_gathering)
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions) return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {} self._ended.clear()
self._cancelled.clear()
self._running.clear()
self._before_gathering.clear()
if self._interrupt_flag.is_set(): if self._interrupt_flag.is_set():
self._interrupt_flag.clear() self._interrupt_flag.clear()
return results return results
class TaskPool(BaseTaskPool): class TaskPool(BaseTaskPool):
"""
General task pool class.
Attempts to somewhat emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
Tasks in the pool can all belong to the same coroutine function,
but they can also come from any number of different and unrelated coroutine functions.
As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.)
Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment.
Adding tasks blocks **only if** the pool is full at that moment.
"""
async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
"""
Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
This method blocks, **only if** the pool is full.
Args:
func:
The coroutine function to be run as a task within the task pool.
args (optional):
The positional arguments to pass into the function call.
kwargs (optional):
The keyword-arguments to pass into the function call.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
It is run with the task's ID as its only positional argument.
Returns:
The newly spawned task's ID within the pool.
"""
if kwargs is None:
kwargs = {}
return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
-end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
-return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))
-@staticmethod
-def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
+end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]:
+"""
+Creates an arbitrary number of coroutines with the supplied arguments and runs them as new tasks in the pool.
+Each coroutine looks like `func(*args, **kwargs)`.
+This method blocks, **only if** there is not enough room in the pool for the desired number of new tasks.
Args:
func:
The coroutine function to use for spawning the new tasks within the task pool.
args (optional):
The positional arguments to pass into each function call.
kwargs (optional):
The keyword-arguments to pass into each function call.
num (optional):
The number of tasks to spawn with the specified parameters.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of a task.
It is run with the task's ID as its only positional argument.
Returns:
The newly spawned tasks' IDs within the pool as a list of integers.
Raises:
`NotCoroutine` if `func` is not a coroutine function.
`PoolIsClosed` if the pool has been closed already.
"""
ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
assert isinstance(ids, list)
return ids
async def _queue_producer(self, q: Queue, args_iter: Iterator[Any]) -> None:
"""
Keeps the arguments queue from `_map()` full as long as the iterator has elements.
If the `_interrupt_flag` gets set, the loop ends prematurely.
Args:
q:
The queue of function arguments to consume for starting the next task.
args_iter:
The iterator of function arguments to put into the queue.
"""
for arg in args_iter:
if self._interrupt_flag.is_set():
break
await q.put(arg) # This blocks as long as the queue is full.
async def _queue_consumer(self, q: Queue, func: CoroutineFunc, arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Wrapper around `_start_task()` that takes the next element from the arguments queue set up in `_map()`.
Partially constructs the `_queue_callback` function with the same arguments.
Args:
q:
The queue of function arguments to consume for starting the next task.
func:
The coroutine function to use for spawning the tasks within the task pool.
arg_stars (optional):
Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
end_callback (optional):
The actual callback specified to execute after the task (and the next one) has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the task's ID as its only positional argument.
"""
try:
-arg = next(args_iter)
-except StopIteration:
+arg = q.get_nowait()
+except QueueEmpty:
return
-if arg_stars == 0:
-return func(arg)
-if arg_stars == 1:
-return func(*arg)
-if arg_stars == 2:
-return func(**arg)
-raise ValueError
+try:
+await self._start_task(
+star_function(func, arg, arg_stars=arg_stars),
+ignore_closed=True,
+end_callback=partial(TaskPool._queue_callback, self, q=q, func=func, arg_stars=arg_stars,
+end_callback=end_callback, cancel_callback=cancel_callback),
+cancel_callback=cancel_callback
+)
+finally:
+q.task_done()
async def _queue_callback(self, task_id: int, q: Queue, func: CoroutineFunc, arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Wrapper around an end callback function passed into the `_map()` method.
Triggers the next `_queue_consumer` with the same arguments.
Args:
task_id:
The ID of the ending task.
q:
The queue of function arguments to consume for starting the next task.
func:
The coroutine function to use for spawning the tasks within the task pool.
arg_stars (optional):
Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
end_callback (optional):
The actual callback specified to execute after the task (and the next one) has ended.
It is run with the `task_id` as its only positional argument.
cancel_callback (optional):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the `task_id` as its only positional argument.
"""
await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
await execute_optional(end_callback, args=(task_id,))
def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
"""
Helper function for `_map()`.
Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` of its elements to a new `asyncio.Queue`.
A coroutine joining the queue is added to the pool's `_before_gathering` list and the queue is returned.
If the iterable contains fewer than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer`
is started as a separate task with the arguments queue and an iterator of the remaining arguments.
Args:
args_iter:
The iterable of function arguments passed into `_map()` to use for creating the new tasks.
num_tasks:
The maximum number of the new tasks to run concurrently that was passed into `_map()`.
Returns:
The newly created and filled arguments queue for spawning new tasks.
"""
# Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run
# concurrently because the size of the queue is what will determine the number of immediately started tasks in
# the `_map()` method and each of those will only ever start (at most) one other task upon ending.
args_queue = Queue(maxsize=num_tasks)
self._before_gathering.append(join_queue(args_queue))
args_iter = iter(args_iter)
try:
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
# tasks, which will be at most `num_tasks` (meaning the queue will be full).
for i in range(num_tasks):
args_queue.put_nowait(next(args_iter))
except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
pass
else:
# There may be more elements in the arguments iterator, so we need the `_queue_producer`.
# It will have exclusive access to the `args_iter` from now on.
# Since the queue is full already, it will wait until one of the tasks in the first batch ends,
# before putting the next item in it.
create_task(self._queue_producer(args_queue, args_iter))
return args_queue
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
-if self._all_tasks_known_flag.is_set():
-self._all_tasks_known_flag.clear()
-args_iter = iter(args_iter)
-async def _start_next_coroutine() -> bool:
-cor = self._get_next_coroutine(func, args_iter, arg_stars)
-if cor is None or self._interrupt_flag.is_set():
-self._all_tasks_known_flag.set()
-return True
-await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
-return False
-async def _start_next(task_id: int) -> None:
-await _start_next_coroutine()
-await _execute_function(end_callback, args=(task_id, ))
-for _ in range(num_tasks):
-reached_end = await _start_next_coroutine()
-if reached_end:
-break
+"""
+Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
+TODO: If task groups are implemented, consider adding all tasks from one call of this method to the same group
+and referring to "group size" rather than chunk/batch size.
+Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable.
+This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
+It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
+Args:
+func:
+The coroutine function to use for spawning the new tasks within the task pool.
+args_iter:
+The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
+arg_stars (optional):
+Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
+num_tasks (optional):
+The maximum number of the new tasks to run concurrently.
+end_callback (optional):
+A callback to execute after a task has ended.
+It is run with the task's ID as its only positional argument.
+cancel_callback (optional):
+A callback to execute after cancellation of a task.
+It is run with the task's ID as its only positional argument.
+Raises:
+`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed.
+"""
+if not self.is_open:
+raise exceptions.PoolIsClosed("Cannot start new tasks")
+args_queue = self._set_up_args_queue(args_iter, num_tasks)
+for _ in range(args_queue.qsize()):
+# This is where blocking can occur, if the pool is full.
+await self._queue_consumer(args_queue, func,
+arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
-async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
+async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
-await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
-end_callback=end_callback, cancel_callback=cancel_callback)
+"""
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
Each coroutine looks like `func(arg)`, `arg` being an element from the iterable.
Once the first batch of tasks has started to run, this method returns.
As soon as one of them finishes, it triggers the start of a new task (assuming there is room in the pool)
consuming the next element from the arguments iterable.
If the pool size never imposes a limit, this ensures that the desired number of tasks from this call is
running concurrently within the pool almost the entire time.
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
Args:
func:
The coroutine function to use for spawning the new tasks within the task pool.
arg_iter:
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
num_tasks (optional):
The maximum number of the new tasks to run concurrently.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of a task.
It is run with the task's ID as its only positional argument.
Raises:
`PoolIsClosed` if the pool has been closed.
`NotCoroutine` if `func` is not a coroutine function.
"""
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked as
positional arguments to the function.
Each coroutine then looks like `func(*arg)`, `arg` being an element from `args_iter`.
"""
await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
"""
Like `map()` except that the elements of `kwargs_iter` are expected to be mappings themselves, to be unpacked as
keyword-arguments to the function.
Each coroutine then looks like `func(**arg)`, `arg` being an element from `kwargs_iter`.
"""
await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
@ -403,6 +641,8 @@ class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None:
+if not iscoroutinefunction(func):
+raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
self._func: CoroutineFunc = func
self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
@ -437,13 +677,3 @@ class SimpleTaskPool(BaseTaskPool):
def stop_all(self) -> List[int]:
return self.stop(self.size)
-async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None:
-if kwargs is None:
-kwargs = {}
-if callable(func):
-if iscoroutinefunction(func):
-await func(*args, **kwargs)
-else:
-func(*args, **kwargs)
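
Taken together, the reworked `TaskPool._map()` machinery means a `map()` call starts at most `num_tasks` tasks up front and then lets each ending task pull the next argument from the internal queue. A small end-to-end sketch of that flow; the constructor defaults and the `close()` call are assumptions based on the rest of the library rather than lines shown in this diff:

import asyncio
from asyncio_taskpool.pool import TaskPool

async def work(item):
    await asyncio.sleep(0.01)
    return item * 2

async def main():
    task_pool = TaskPool(name="demo")
    # At most 2 `work` tasks from this call run at any one time; as each one
    # ends, its end callback consumes the next argument from the queue.
    await task_pool.map(work, range(10), num_tasks=2)
    task_pool.close()  # assumed: the method that flips the pool's `_open` flag
    results = await task_pool.gather(return_exceptions=True)
    print(results)

asyncio.run(main())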

View File

@ -1,10 +1,15 @@
from asyncio.streams import StreamReader, StreamWriter
-from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
+from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
+T = TypeVar('T')
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
+AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable
CancelCallbackT = Callable

View File

@ -1,7 +1,9 @@
import asyncio
from asyncio.exceptions import CancelledError
+from asyncio.queues import Queue
from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
+from typing import Type
from asyncio_taskpool import pool, exceptions
@ -14,7 +16,12 @@ class TestException(Exception):
pass
-class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
+class CommonTestCase(IsolatedAsyncioTestCase):
+TEST_CLASS: Type[pool.BaseTaskPool] = pool.BaseTaskPool
+TEST_POOL_SIZE: int = 420
+TEST_POOL_NAME: str = 'test123'
+task_pool: pool.BaseTaskPool
log_lvl: int
@classmethod
@ -26,35 +33,38 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
+def get_task_pool_init_params(self) -> dict:
+return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
def setUp(self) -> None:
-self._pools = getattr(pool.BaseTaskPool, '_pools')
-# These three methods are called during initialization, so we mock them by default during setup
-self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
-self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
-self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
+self._pools = self.TEST_CLASS._pools
+# These three methods are called during initialization, so we mock them by default during setup:
+self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
+self.pool_size_patcher = patch.object(self.TEST_CLASS, 'pool_size', new_callable=PropertyMock)
+self.dunder_str_patcher = patch.object(self.TEST_CLASS, '__str__')
self.mock__add_pool = self._add_pool_patcher.start()
self.mock_pool_size = self.pool_size_patcher.start()
-self.mock___str__ = self.__str___patcher.start()
+self.mock___str__ = self.dunder_str_patcher.start()
self.mock__add_pool.return_value = self.mock_idx = 123
self.mock___str__.return_value = self.mock_str = 'foobar'
-# Test pool parameters:
-self.test_pool_size, self.test_pool_name = 420, 'test123'
-self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
+self.task_pool = self.TEST_CLASS(**self.get_task_pool_init_params())
def tearDown(self) -> None:
-setattr(pool.TaskPool, '_pools', self._pools)
+self.TEST_CLASS._pools.clear()
self._add_pool_patcher.stop()
self.pool_size_patcher.stop()
-self.__str___patcher.stop()
+self.dunder_str_patcher.stop()
+class BaseTaskPoolTestCase(CommonTestCase):
def test__add_pool(self):
self.assertListEqual(EMPTY_LIST, self._pools)
self._add_pool_patcher.stop()
-output = pool.TaskPool._add_pool(self.task_pool)
+output = pool.BaseTaskPool._add_pool(self.task_pool)
self.assertEqual(0, output)
-self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
+self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
@ -66,27 +76,26 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx)
-self.assertEqual(self.test_pool_name, self.task_pool._name)
-self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
-self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
+self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
+self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool)
-self.mock_pool_size.assert_called_once_with(self.test_pool_size)
+self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with()
def test___str__(self):
-self.__str___patcher.stop()
-expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
+self.dunder_str_patcher.stop()
+expected_str = f'{pool.BaseTaskPool.__name__}-{self.TEST_POOL_NAME}'
self.assertEqual(expected_str, str(self.task_pool))
-setattr(self.task_pool, '_name', None)
+self.task_pool._name = None
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool))
def test_pool_size(self):
self.pool_size_patcher.stop()
-self.task_pool._pool_size = self.test_pool_size
-self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
+self.task_pool._pool_size = self.TEST_POOL_SIZE
+self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size)
with self.assertRaises(ValueError):
self.task_pool.pool_size = -1
@ -123,9 +132,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
-@patch.object(pool, '_execute_function')
+@patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
-async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
+async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task
@ -134,11 +143,11 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
mock__task_name.assert_called_with(task_id)
-mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
+mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
-@patch.object(pool, '_execute_function')
+@patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
-async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
+async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
@ -151,9 +160,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
-mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
+mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock__task_name.reset_mock()
-mock__execute_function.reset_mock()
+mock_execute_optional.reset_mock()
# End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
@ -163,7 +172,7 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
-mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
+mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool.BaseTaskPool, '_task_ending')
@patch.object(pool.BaseTaskPool, '_task_cancellation')
@ -213,14 +222,27 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
-mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
+mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_not_called()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
mock_is_open.return_value = ignore_closed = False
+mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
+await mock_awaitable
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
@ -231,8 +253,10 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
reset_mocks()
ignore_closed = True
+mock_awaitable = mock_coroutine()
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
+await mock_awaitable
self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
@ -246,11 +270,13 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
+mock_awaitable = mock_coroutine()
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
+await mock_awaitable
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
@ -325,11 +351,11 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertFalse(self.task_pool._open)
async def test_gather(self):
-mock_wait = AsyncMock()
-self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR)
+mock_queue_join = AsyncMock()
+self.task_pool._before_gathering = before_gather = [mock_queue_join()]
self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()}
@ -341,25 +367,210 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled)
self.assertDictEqual(self.task_pool._running, running)
+self.assertListEqual(self.task_pool._before_gathering, before_gather)
self.assertTrue(self.task_pool._interrupt_flag.is_set())
-mock_wait.assert_not_awaited()
self.task_pool._open = False
-def check_assertions() -> None:
+def check_assertions(output) -> None:
self.assertListEqual([FOO, test_exception, BAR], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
+self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
-mock_wait.assert_awaited_once_with()
-output = await self.task_pool.gather(return_exceptions=True)
-check_assertions()
-mock_wait.reset_mock()
+check_assertions(await self.task_pool.gather(return_exceptions=True))
+self.task_pool._before_gathering = [mock_queue_join()]
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
-output = await self.task_pool.gather(return_exceptions=True)
-check_assertions()
+check_assertions(await self.task_pool.gather(return_exceptions=True))
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 12345
mock_awaitable = MagicMock()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
mock_func.reset_mock()
mock__start_task.reset_mock()
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_apply_one')
async def test_apply(self, mock__apply_one: AsyncMock):
mock__apply_one.return_value = mock_id = 67890
mock_func, num = MagicMock(), 3
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
expected_output = num * [mock_id]
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
async def test__queue_producer(self):
mock_put = AsyncMock()
mock_q = MagicMock(put=mock_put)
args = (FOO, BAR, 123)
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_has_awaits([call(arg) for arg in args])
mock_put.reset_mock()
self.task_pool._interrupt_flag.set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_not_awaited()
@patch.object(pool, 'partial')
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
mock_partial: MagicMock):
mock_partial.return_value = queue_callback = 'not really'
mock_star_function.return_value = awaitable = 'totally an awaitable'
q, arg = Queue(), 420.69
q.put_nowait(arg)
mock_func, stars = MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_closed=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
q=q, func=mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
mock_partial.assert_not_called()
@patch.object(pool, 'execute_optional')
@patch.object(pool.TaskPool, '_queue_consumer')
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
task_id, mock_q = 420, MagicMock()
mock_func, stars = MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_func, stars, end_cb, cancel_cb))
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
@patch.object(pool, 'iter')
@patch.object(pool, 'create_task')
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
mock_create_task: MagicMock, mock_iter: MagicMock):
args, num_tasks = (FOO, BAR, 1, 2, 3), 2
mock_join_queue.return_value = mock_join = 'awaitable'
mock_iter.return_value = args_iter = iter(args)
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(num_tasks, output_q.qsize())
for arg in args[:num_tasks]:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
for arg in args[num_tasks:]:
self.assertEqual(arg, next(args_iter))
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_called_once_with(output_q, args_iter)
mock_create_task.assert_called_once_with(mock_producer_coro)
self.task_pool._before_gathering.clear()
mock_join_queue.reset_mock()
mock__queue_producer.reset_mock()
mock_create_task.reset_mock()
num_tasks = 6
mock_iter.return_value = args_iter = iter(args)
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(len(args), output_q.qsize())
for arg in args:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_not_called()
mock_create_task.assert_not_called()
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
@patch.object(pool.TaskPool, 'is_open', new_callable=PropertyMock)
async def test__map(self, mock_is_open: MagicMock, mock__set_up_args_queue: MagicMock,
mock__queue_consumer: AsyncMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_func, stars = MagicMock(), 3
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
mock_is_open.return_value = False
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
mock_is_open.assert_called_once_with()
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_is_open.reset_mock()
mock_is_open.return_value = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)])
@patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock):
mock_func = MagicMock()
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
args_iter, num_tasks = ([FOO], [BAR]), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)