Compare commits

..

10 Commits

11 changed files with 763 additions and 152 deletions

View File

@ -14,6 +14,11 @@ See [USAGE.md](usage/USAGE.md)
Python Version 3.8+, tested on Linux Python Version 3.8+, tested on Linux
## Testing
Install `dev` dependencies or just manually install `coverage` with `pip`.
Run the [`coverage.sh`](coverage.sh) shell script to execute all unit tests and receive the coverage report.
## Building from source ## Building from source
Run `python -m build` Run `python -m build`

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.1.1 version = 0.2.0
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -3,6 +3,6 @@ MSG_BYTES = 1024
CMD_START = 'start' CMD_START = 'start'
CMD_STOP = 'stop' CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all' CMD_STOP_ALL = 'stop_all'
CMD_SIZE = 'size' CMD_NUM_RUNNING = 'num_running'
CMD_FUNC = 'func' CMD_FUNC = 'func'
CLIENT_EXIT = 'exit' CLIENT_EXIT = 'exit'

View File

@ -2,7 +2,7 @@ class PoolException(Exception):
pass pass
class PoolIsClosed(PoolException): class PoolIsLocked(PoolException):
pass pass
@ -22,7 +22,7 @@ class InvalidTaskID(PoolException):
pass pass
class PoolStillOpen(PoolException): class PoolStillUnlocked(PoolException):
pass pass

View File

@ -1,4 +1,5 @@
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from typing import Any, Optional from typing import Any, Optional
from .types import T, AnyCallableT, ArgsT, KwArgsT from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -22,3 +23,7 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
if arg_stars == 2: if arg_stars == 2:
return function(**arg) return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.") raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
async def join_queue(q: Queue) -> None:
await q.join()

View File

@ -10,7 +10,7 @@ from math import inf
from typing import Any, Awaitable, Dict, Iterable, Iterator, List from typing import Any, Awaitable, Dict, Iterable, Iterator, List
from . import exceptions from . import exceptions
from .helpers import execute_optional, star_function from .helpers import execute_optional, star_function, join_queue
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
@ -31,7 +31,7 @@ class BaseTaskPool:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list.""" """Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self._enough_room: Semaphore = Semaphore() self._enough_room: Semaphore = Semaphore()
self.pool_size = pool_size self.pool_size = pool_size
self._open: bool = True self._locked: bool = False
self._counter: int = 0 self._counter: int = 0
self._running: Dict[int, Task] = {} self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {} self._cancelled: Dict[int, Task] = {}
@ -71,9 +71,21 @@ class BaseTaskPool:
self._pool_size = value self._pool_size = value
@property @property
def is_open(self) -> bool: def is_locked(self) -> bool:
"""Returns `True` if more the pool has not been closed yet.""" """Returns `True` if more the pool has been locked (see below)."""
return self._open return self._locked
def lock(self) -> None:
"""Disallows any more tasks to be started in the pool."""
if not self._locked:
self._locked = True
log.info("%s is locked!", str(self))
def unlock(self) -> None:
"""Allows new tasks to be started in the pool."""
if self._locked:
self._locked = False
log.info("%s was unlocked.", str(self))
@property @property
def num_running(self) -> int: def num_running(self) -> int:
@ -187,7 +199,7 @@ class BaseTaskPool:
finally: finally:
await self._task_ending(task_id, custom_callback=end_callback) await self._task_ending(task_id, custom_callback=end_callback)
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None, async def _start_task(self, awaitable: Awaitable, ignore_lock: bool = False, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int: cancel_callback: CancelCallbackT = None) -> int:
""" """
Starts a coroutine as a new task in the pool. Starts a coroutine as a new task in the pool.
@ -197,8 +209,8 @@ class BaseTaskPool:
Args: Args:
awaitable: awaitable:
The actual coroutine to be run within the task pool. The actual coroutine to be run within the task pool.
ignore_closed (optional): ignore_lock (optional):
If `True`, even if the pool is closed, the task will still be started. If `True`, even if the pool is locked, the task will still be started.
end_callback (optional): end_callback (optional):
A callback to execute after the task has ended. A callback to execute after the task has ended.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
@ -208,12 +220,12 @@ class BaseTaskPool:
Raises: Raises:
`asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine. `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`. `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked and `ignore_lock` is `False`.
""" """
if not iscoroutine(awaitable): if not iscoroutine(awaitable):
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed): if self._locked and not ignore_lock:
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsLocked("Cannot start new tasks")
await self._enough_room.acquire() await self._enough_room.acquire()
task_id = self._counter task_id = self._counter
self._counter += 1 self._counter += 1
@ -297,21 +309,17 @@ class BaseTaskPool:
return_exceptions (optional): Passed directly into `gather`. return_exceptions (optional): Passed directly into `gather`.
""" """
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions) results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
self._ended = self._cancelled = {} self._ended.clear()
self._cancelled.clear()
if self._interrupt_flag.is_set(): if self._interrupt_flag.is_set():
self._interrupt_flag.clear() self._interrupt_flag.clear()
return results return results
def close(self) -> None:
"""Disallows any more tasks to be started in the pool."""
self._open = False
log.info("%s is closed!", str(self))
async def gather(self, return_exceptions: bool = False): async def gather(self, return_exceptions: bool = False):
""" """
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks. Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
The `close()` method must have been called prior to this. The `lock()` method must have been called prior to this.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks. Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
@ -326,14 +334,17 @@ class BaseTaskPool:
return_exceptions (optional): Passed directly into `gather`. return_exceptions (optional): Passed directly into `gather`.
Raises: Raises:
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet. `asyncio_taskpool.exceptions.PoolStillUnlocked` if the pool has not been locked yet.
""" """
if self._open: if not self._locked:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
await gather(*self._before_gathering) await gather(*self._before_gathering)
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions) return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {} self._ended.clear()
self._cancelled.clear()
self._running.clear()
self._before_gathering.clear()
if self._interrupt_flag.is_set(): if self._interrupt_flag.is_set():
self._interrupt_flag.clear() self._interrupt_flag.clear()
return results return results
@ -411,7 +422,7 @@ class TaskPool(BaseTaskPool):
Raises: Raises:
`NotCoroutine` if `func` is not a coroutine function. `NotCoroutine` if `func` is not a coroutine function.
`PoolIsClosed` if the pool has been closed already. `PoolIsLocked` if the pool has been locked already.
""" """
ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))) ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions # TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
@ -434,7 +445,7 @@ class TaskPool(BaseTaskPool):
break break
await q.put(arg) # This blocks as long as the queue is full. await q.put(arg) # This blocks as long as the queue is full.
async def _queue_consumer(self, q: Queue, func: CoroutineFunc, arg_stars: int = 0, async def _queue_consumer(self, q: Queue, first_batch_started: Event, func: CoroutineFunc, arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
""" """
Wrapper around the `_start_task()` taking the next element from the arguments queue set up in `_map()`. Wrapper around the `_start_task()` taking the next element from the arguments queue set up in `_map()`.
@ -443,6 +454,9 @@ class TaskPool(BaseTaskPool):
Args: Args:
q: q:
The queue of function arguments to consume for starting the next task. The queue of function arguments to consume for starting the next task.
first_batch_started:
The event flag to wait for, before launching the next consumer.
It can only set by the `_map()` method, which happens after the first batch of task has been started.
func: func:
The coroutine function to use for spawning the tasks within the task pool. The coroutine function to use for spawning the tasks within the task pool.
arg_stars (optional): arg_stars (optional):
@ -461,16 +475,18 @@ class TaskPool(BaseTaskPool):
try: try:
await self._start_task( await self._start_task(
star_function(func, arg, arg_stars=arg_stars), star_function(func, arg, arg_stars=arg_stars),
ignore_closed=True, ignore_lock=True,
end_callback=partial(TaskPool._queue_callback, self, q=q, func=func, arg_stars=arg_stars, end_callback=partial(TaskPool._queue_callback, self, q=q, first_batch_started=first_batch_started,
end_callback=end_callback, cancel_callback=cancel_callback), func=func, arg_stars=arg_stars, end_callback=end_callback,
cancel_callback=cancel_callback),
cancel_callback=cancel_callback cancel_callback=cancel_callback
) )
finally: finally:
q.task_done() q.task_done()
async def _queue_callback(self, task_id: int, q: Queue, func: CoroutineFunc, arg_stars: int = 0, async def _queue_callback(self, task_id: int, q: Queue, first_batch_started: Event, func: CoroutineFunc,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: arg_stars: int = 0, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> None:
""" """
Wrapper around an end callback function passed into the `_map()` method. Wrapper around an end callback function passed into the `_map()` method.
Triggers the next `_queue_consumer` with the same arguments. Triggers the next `_queue_consumer` with the same arguments.
@ -480,6 +496,9 @@ class TaskPool(BaseTaskPool):
The ID of the ending task. The ID of the ending task.
q: q:
The queue of function arguments to consume for starting the next task. The queue of function arguments to consume for starting the next task.
first_batch_started:
The event flag to wait for, before launching the next consumer.
It can only set by the `_map()` method, which happens after the first batch of task has been started.
func: func:
The coroutine function to use for spawning the tasks within the task pool. The coroutine function to use for spawning the tasks within the task pool.
arg_stars (optional): arg_stars (optional):
@ -491,19 +510,64 @@ class TaskPool(BaseTaskPool):
The callback that was specified to execute after cancellation of the task (and the next one). The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the `task_id` as its only positional argument. It is run with the `task_id` as its only positional argument.
""" """
await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback) await first_batch_started.wait()
await self._queue_consumer(q, first_batch_started, func, arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback)
await execute_optional(end_callback, args=(task_id,)) await execute_optional(end_callback, args=(task_id,))
def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
"""
Helper function for `_map()`.
Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` to a new `asyncio.Queue`.
The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.
If the iterable contains less than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer`
is started as a separate task with the arguments queue and and iterator of the remaining arguments.
Args:
args_iter:
The iterable of function arguments passed into `_map()` to use for creating the new tasks.
num_tasks:
The maximum number of the new tasks to run concurrently that was passed into `_map()`.
Returns:
The newly created and filled arguments queue for spawning new tasks.
"""
# Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run
# concurrently because the size of the queue is what will determine the number of immediately started tasks in
# the `_map()` method and each of those will only ever start (at most) one other task upon ending.
args_queue = Queue(maxsize=num_tasks)
self._before_gathering.append(join_queue(args_queue))
args_iter = iter(args_iter)
try:
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
# tasks, which will be at most `num_tasks` (meaning the queue will be full).
for i in range(num_tasks):
args_queue.put_nowait(next(args_iter))
except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
pass
else:
# There may be more elements in the arguments iterator, so we need the `_queue_producer`.
# It will have exclusive access to the `args_iter` from now on.
# Since the queue is full already, it will wait until one of the tasks in the first batch ends,
# before putting the next item in it.
create_task(self._queue_producer(args_queue, args_iter))
return args_queue
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
""" """
Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches. Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
TODO: If task groups are implemented, consider adding all tasks from one call of this method to the same group
and referring to "group size" rather than chunk/batch size.
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable. Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the iterable.
This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks. This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
It sets up an internal queue which is filled while consuming the arguments iterable. It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
The queue's `join()` method is added to the pool's `_before_gathering` list.
Args: Args:
func: func:
@ -522,33 +586,23 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
Raises: Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed. `asyncio_taskpool.exceptions.PoolIsLocked` if the pool has been locked.
""" """
if not self.is_open: if not self._locked:
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsLocked("Cannot start new tasks")
args_queue = Queue(maxsize=num_tasks) args_queue = self._set_up_args_queue(args_iter, num_tasks)
self._before_gathering.append(args_queue.join()) # We need a flag to ensure that starting all tasks from the first batch here will not be blocked by the
args_iter = iter(args_iter) # `_queue_callback` triggered by one or more of them.
try: # This could happen, e.g. if the pool has just enough room for one more task, but the queue here contains more
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of # than one element, and the pool remains full until after the first task of the first batch ends. Then the
# tasks, which will be at most `num_tasks` (meaning the queue will be full). # callback might trigger the next `_queue_consumer` before this method can, which will keep it blocked.
for i in range(num_tasks): first_batch_started = Event()
args_queue.put_nowait(next(args_iter)) for _ in range(args_queue.qsize()):
except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
num_tasks = args_queue.qsize()
else:
# There may be more elements in the arguments iterator, so we need the `_queue_producer`.
# It will have exclusive access to the `args_iter` from now on.
# If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting
# the next item in it.
create_task(self._queue_producer(args_queue, args_iter))
for _ in range(num_tasks):
# This is where blocking can occur, if the pool is full. # This is where blocking can occur, if the pool is full.
await self._queue_consumer(args_queue, func, await self._queue_consumer(args_queue, first_batch_started, func,
arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback) arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
# Now the callbacks can immediately trigger more tasks.
first_batch_started.set()
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1, async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
@ -581,7 +635,7 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
Raises: Raises:
`PoolIsClosed` if the pool has been closed. `PoolIsLocked` if the pool has been locked.
`NotCoroutine` if `func` is not a coroutine function. `NotCoroutine` if `func` is not a coroutine function.
""" """
await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks, await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,
@ -609,9 +663,45 @@ class TaskPool(BaseTaskPool):
class SimpleTaskPool(BaseTaskPool): class SimpleTaskPool(BaseTaskPool):
"""
Simplified task pool class.
A `SimpleTaskPool` instance can manage an arbitrary number of concurrent tasks,
but they **must** come from a single coroutine function, called with the same arguments.
The coroutine function and its arguments are defined upon initialization.
As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.)
Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment.
However, since all tasks come from the same function-arguments-combination, the specificity of the `cancel()` method
is probably unnecessary. Instead, a simpler `stop()` method is introduced.
Adding tasks blocks **only if** the pool is full at that moment.
"""
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None: pool_size: int = inf, name: str = None) -> None:
"""
Args:
func:
The function to use for spawning new tasks within the pool.
args (optional):
The positional arguments to pass into each function call.
kwargs (optional):
The keyword-arguments to pass into each function call.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of a task.
It is run with the task's ID as its only positional argument.
pool_size (optional):
The maximum number of tasks allowed to run concurrently in the pool
name (optional):
An optional name for the pool.
"""
if not iscoroutinefunction(func): if not iscoroutinefunction(func):
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}") raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
self._func: CoroutineFunc = func self._func: CoroutineFunc = func
@ -619,32 +709,39 @@ class SimpleTaskPool(BaseTaskPool):
self._kwargs: KwArgsT = kwargs if kwargs is not None else {} self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
self._end_callback: EndCallbackT = end_callback self._end_callback: EndCallbackT = end_callback
self._cancel_callback: CancelCallbackT = cancel_callback self._cancel_callback: CancelCallbackT = cancel_callback
super().__init__(name=name) super().__init__(pool_size=pool_size, name=name)
@property @property
def func_name(self) -> str: def func_name(self) -> str:
"""Returns the name of the coroutine function used in the pool."""
return self._func.__name__ return self._func.__name__
@property
def size(self) -> int:
return self.num_running
async def _start_one(self) -> int: async def _start_one(self) -> int:
"""Starts a single new task within the pool and returns its ID."""
return await self._start_task(self._func(*self._args, **self._kwargs), return await self._start_task(self._func(*self._args, **self._kwargs),
end_callback=self._end_callback, cancel_callback=self._cancel_callback) end_callback=self._end_callback, cancel_callback=self._cancel_callback)
async def start(self, num: int = 1) -> List[int]: async def start(self, num: int = 1) -> List[int]:
return [await self._start_one() for _ in range(num)] """Starts `num` new tasks within the pool and returns their IDs as a list."""
ids = await gather(*(self._start_one() for _ in range(num)))
assert isinstance(ids, list) # for PyCharm (see above to-do-item)
return ids
def stop(self, num: int = 1) -> List[int]: def stop(self, num: int = 1) -> List[int]:
num = min(num, self.size) """
Cancels `num` running tasks within the pool and returns their IDs as a list.
The tasks are canceled in LIFO order, meaning tasks started later will be stopped before those started earlier.
If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled.
"""
ids = [] ids = []
for i, task_id in enumerate(reversed(self._running)): for i, task_id in enumerate(reversed(self._running)):
if i >= num: if i >= num:
break break # We got the desired number of task IDs, there may well be more tasks left to keep running
ids.append(task_id) ids.append(task_id)
self.cancel(*ids) self.cancel(*ids)
return ids return ids
def stop_all(self) -> List[int]: def stop_all(self) -> List[int]:
return self.stop(self.size) """Cancels all running tasks and returns their IDs as a list."""
return self.stop(self.num_running)

View File

@ -63,8 +63,8 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
writer.write(str(self._pool.stop_all()).encode()) writer.write(str(self._pool.stop_all()).encode())
def _pool_size(self, writer: StreamWriter) -> None: def _pool_size(self, writer: StreamWriter) -> None:
log.debug("%s requests pool size", self.client_class.__name__) log.debug("%s requests number of running tasks", self.client_class.__name__)
writer.write(str(self._pool.size).encode()) writer.write(str(self._pool.num_running).encode())
def _pool_func(self, writer: StreamWriter) -> None: def _pool_func(self, writer: StreamWriter) -> None:
log.debug("%s requests pool function", self.client_class.__name__) log.debug("%s requests pool function", self.client_class.__name__)
@ -83,7 +83,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
self._stop_tasks(writer, arg) self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL: elif cmd == constants.CMD_STOP_ALL:
self._stop_all_tasks(writer) self._stop_all_tasks(writer)
elif cmd == constants.CMD_SIZE: elif cmd == constants.CMD_NUM_RUNNING:
self._pool_size(writer) self._pool_size(writer)
elif cmd == constants.CMD_FUNC: elif cmd == constants.CMD_FUNC:
self._pool_func(writer) self._pool_func(writer)

67
tests/test_helpers.py Normal file
View File

@ -0,0 +1,67 @@
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock
from asyncio_taskpool import helpers
class HelpersTestCase(IsolatedAsyncioTestCase):
async def test_execute_optional(self):
f, args, kwargs = NonCallableMagicMock(), [1, 2], None
a = [f, args, kwargs] # to avoid IDE nagging
self.assertIsNone(await helpers.execute_optional(*a))
expected_output = 'foo'
f = MagicMock(return_value=expected_output)
output = await helpers.execute_optional(f, args, kwargs)
self.assertEqual(expected_output, output)
f.assert_called_once_with(*args)
f.reset_mock()
kwargs = {'a': 100, 'b': 200}
output = await helpers.execute_optional(f, args, kwargs)
self.assertEqual(expected_output, output)
f.assert_called_once_with(*args, **kwargs)
f = AsyncMock(return_value=expected_output)
output = await helpers.execute_optional(f, args, kwargs)
self.assertEqual(expected_output, output)
f.assert_awaited_once_with(*args, **kwargs)
def test_star_function(self):
expected_output = 'bar'
f = MagicMock(return_value=expected_output)
a = (1, 2, 3)
stars = 0
output = helpers.star_function(f, a, stars)
self.assertEqual(expected_output, output)
f.assert_called_once_with(a)
f.reset_mock()
stars = 1
output = helpers.star_function(f, a, stars)
self.assertEqual(expected_output, output)
f.assert_called_once_with(*a)
f.reset_mock()
a = {'a': 1, 'b': 2}
stars = 2
output = helpers.star_function(f, a, stars)
self.assertEqual(expected_output, output)
f.assert_called_once_with(**a)
with self.assertRaises(ValueError):
helpers.star_function(f, a, 3)
with self.assertRaises(ValueError):
helpers.star_function(f, a, -1)
with self.assertRaises(ValueError):
helpers.star_function(f, a, 123456789)
async def test_join_queue(self):
mock_join = AsyncMock()
mock_queue = MagicMock(join=mock_join)
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()

View File

@ -1,7 +1,9 @@
import asyncio import asyncio
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.queues import Queue
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type
from asyncio_taskpool import pool, exceptions from asyncio_taskpool import pool, exceptions
@ -14,7 +16,12 @@ class TestException(Exception):
pass pass
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): class CommonTestCase(IsolatedAsyncioTestCase):
TEST_CLASS: Type[pool.BaseTaskPool] = pool.BaseTaskPool
TEST_POOL_SIZE: int = 420
TEST_POOL_NAME: str = 'test123'
task_pool: pool.BaseTaskPool
log_lvl: int log_lvl: int
@classmethod @classmethod
@ -26,39 +33,42 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl) pool.log.setLevel(cls.log_lvl)
def setUp(self) -> None: def get_task_pool_init_params(self) -> dict:
self._pools = getattr(pool.BaseTaskPool, '_pools') return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
# These three methods are called during initialization, so we mock them by default during setup def setUp(self) -> None:
self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool') self._pools = self.TEST_CLASS._pools
self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock) # These three methods are called during initialization, so we mock them by default during setup:
self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__') self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
self.pool_size_patcher = patch.object(self.TEST_CLASS, 'pool_size', new_callable=PropertyMock)
self.dunder_str_patcher = patch.object(self.TEST_CLASS, '__str__')
self.mock__add_pool = self._add_pool_patcher.start() self.mock__add_pool = self._add_pool_patcher.start()
self.mock_pool_size = self.pool_size_patcher.start() self.mock_pool_size = self.pool_size_patcher.start()
self.mock___str__ = self.__str___patcher.start() self.mock___str__ = self.dunder_str_patcher.start()
self.mock__add_pool.return_value = self.mock_idx = 123 self.mock__add_pool.return_value = self.mock_idx = 123
self.mock___str__.return_value = self.mock_str = 'foobar' self.mock___str__.return_value = self.mock_str = 'foobar'
# Test pool parameters: self.task_pool = self.TEST_CLASS(**self.get_task_pool_init_params())
self.test_pool_size, self.test_pool_name = 420, 'test123'
self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
def tearDown(self) -> None: def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools) self.TEST_CLASS._pools.clear()
self._add_pool_patcher.stop() self._add_pool_patcher.stop()
self.pool_size_patcher.stop() self.pool_size_patcher.stop()
self.__str___patcher.stop() self.dunder_str_patcher.stop()
class BaseTaskPoolTestCase(CommonTestCase):
def test__add_pool(self): def test__add_pool(self):
self.assertListEqual(EMPTY_LIST, self._pools) self.assertListEqual(EMPTY_LIST, self._pools)
self._add_pool_patcher.stop() self._add_pool_patcher.stop()
output = pool.TaskPool._add_pool(self.task_pool) output = pool.BaseTaskPool._add_pool(self.task_pool)
self.assertEqual(0, output) self.assertEqual(0, output)
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self): def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
self.assertTrue(self.task_pool._open) self.assertFalse(self.task_pool._locked)
self.assertEqual(0, self.task_pool._counter) self.assertEqual(0, self.task_pool._counter)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
@ -66,27 +76,26 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(0, self.task_pool._num_cancelled) self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended) self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertEqual(self.test_pool_name, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set()) self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.test_pool_size) self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with() self.mock___str__.assert_called_once_with()
def test___str__(self): def test___str__(self):
self.__str___patcher.stop() self.dunder_str_patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.TEST_POOL_NAME}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
setattr(self.task_pool, '_name', None) self.task_pool._name = None
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
def test_pool_size(self): def test_pool_size(self):
self.pool_size_patcher.stop() self.pool_size_patcher.stop()
self.task_pool._pool_size = self.test_pool_size self.task_pool._pool_size = self.TEST_POOL_SIZE
self.assertEqual(self.test_pool_size, self.task_pool.pool_size) self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.task_pool.pool_size = -1 self.task_pool.pool_size = -1
@ -94,9 +103,23 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool.pool_size = new_size = 69 self.task_pool.pool_size = new_size = 69
self.assertEqual(new_size, self.task_pool._pool_size) self.assertEqual(new_size, self.task_pool._pool_size)
def test_is_open(self): def test_is_locked(self):
self.task_pool._open = FOO self.task_pool._locked = FOO
self.assertEqual(FOO, self.task_pool.is_open) self.assertEqual(FOO, self.task_pool.is_locked)
def test_lock(self):
assert not self.task_pool._locked
self.task_pool.lock()
self.assertTrue(self.task_pool._locked)
self.task_pool.lock()
self.assertTrue(self.task_pool._locked)
def test_unlock(self):
self.task_pool._locked = True
self.task_pool.unlock()
self.assertFalse(self.task_pool._locked)
self.task_pool.unlock()
self.assertFalse(self.task_pool._locked)
def test_num_running(self): def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz'] self.task_pool._running = ['foo', 'bar', 'baz']
@ -202,11 +225,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock) @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock) async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock,
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock, mock_create_task: MagicMock):
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
def reset_mocks() -> None: def reset_mocks() -> None:
mock_is_open.reset_mock()
mock__task_name.reset_mock() mock__task_name.reset_mock()
mock__task_wrapper.reset_mock() mock__task_wrapper.reset_mock()
mock_create_task.reset_mock() mock_create_task.reset_mock()
@ -217,31 +238,27 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool._counter = count = 123 self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123 self.task_pool._enough_room._value = room = 123
with self.assertRaises(exceptions.NotCoroutine): def check_nothing_changed() -> None:
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter) self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running) self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value) self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_not_called()
mock__task_name.assert_not_called() mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called() mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called() mock_create_task.assert_not_called()
reset_mocks() reset_mocks()
mock_is_open.return_value = ignore_closed = False with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
check_nothing_changed()
self.task_pool._locked = True
ignore_closed = False
mock_awaitable = mock_coroutine() mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsClosed): with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._start_task(mock_awaitable, ignore_closed, await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable await mock_awaitable
self.assertEqual(count, self.task_pool._counter) check_nothing_changed()
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
ignore_closed = True ignore_closed = True
mock_awaitable = mock_coroutine() mock_awaitable = mock_coroutine()
@ -252,7 +269,6 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(count + 1, self.task_pool._counter) self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count]) self.assertEqual(mock_task, self.task_pool._running[count])
self.assertEqual(room - 1, self.task_pool._enough_room._value) self.assertEqual(room - 1, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count) mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@ -271,7 +287,6 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(count + 1, self.task_pool._counter) self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running) self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value) self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count) mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@ -336,47 +351,320 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
def test_close(self):
assert self.task_pool._open
self.task_pool.close()
self.assertFalse(self.task_pool._open)
async def test_gather(self): async def test_gather(self):
mock_wait = AsyncMock()
self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
test_exception = TestException() test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR) mock_running_func = AsyncMock(return_value=BAR)
mock_queue_join = AsyncMock()
self.task_pool._before_gathering = before_gather = [mock_queue_join()]
self.task_pool._ended = ended = {123: mock_ended_func()} self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()} self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()} self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set() self.task_pool._interrupt_flag.set()
assert self.task_pool._open assert not self.task_pool._locked
with self.assertRaises(exceptions.PoolStillOpen): with self.assertRaises(exceptions.PoolStillUnlocked):
await self.task_pool.gather() await self.task_pool.gather()
self.assertDictEqual(self.task_pool._ended, ended) self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled) self.assertDictEqual(self.task_pool._cancelled, cancelled)
self.assertDictEqual(self.task_pool._running, running) self.assertDictEqual(self.task_pool._running, running)
self.assertListEqual(self.task_pool._before_gathering, before_gather)
self.assertTrue(self.task_pool._interrupt_flag.is_set()) self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_not_awaited()
self.task_pool._open = False self.task_pool._locked = True
def check_assertions() -> None: def check_assertions(output) -> None:
self.assertListEqual([FOO, test_exception, BAR], output) self.assertListEqual([FOO, test_exception, BAR], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertDictEqual(self.task_pool._running, EMPTY_DICT) self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertFalse(self.task_pool._interrupt_flag.is_set()) self.assertFalse(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_awaited_once_with()
output = await self.task_pool.gather(return_exceptions=True) check_assertions(await self.task_pool.gather(return_exceptions=True))
check_assertions()
mock_wait.reset_mock()
self.task_pool._before_gathering = [mock_queue_join()]
self.task_pool._ended = {123: mock_ended_func()} self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()} self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()} self.task_pool._running = {789: mock_running_func()}
output = await self.task_pool.gather(return_exceptions=True) check_assertions(await self.task_pool.gather(return_exceptions=True))
check_assertions()
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 12345
mock_awaitable = MagicMock()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
mock_func.reset_mock()
mock__start_task.reset_mock()
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_apply_one')
async def test_apply(self, mock__apply_one: AsyncMock):
mock__apply_one.return_value = mock_id = 67890
mock_func, num = MagicMock(), 3
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
expected_output = num * [mock_id]
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
async def test__queue_producer(self):
mock_put = AsyncMock()
mock_q = MagicMock(put=mock_put)
args = (FOO, BAR, 123)
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_has_awaits([call(arg) for arg in args])
mock_put.reset_mock()
self.task_pool._interrupt_flag.set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_not_awaited()
@patch.object(pool, 'partial')
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
mock_partial: MagicMock):
mock_partial.return_value = queue_callback = 'not really'
mock_star_function.return_value = awaitable = 'totally an awaitable'
q, arg = Queue(), 420.69
q.put_nowait(arg)
mock_func, stars = MagicMock(), 3
mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
mock_partial.assert_not_called()
@patch.object(pool, 'execute_optional')
@patch.object(pool.TaskPool, '_queue_consumer')
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
task_id, mock_q = 420, MagicMock()
mock_func, stars = MagicMock(), 3
mock_wait = AsyncMock()
mock_flag = MagicMock(wait=mock_wait)
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb))
mock_wait.assert_awaited_once_with()
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
@patch.object(pool, 'iter')
@patch.object(pool, 'create_task')
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
mock_create_task: MagicMock, mock_iter: MagicMock):
args, num_tasks = (FOO, BAR, 1, 2, 3), 2
mock_join_queue.return_value = mock_join = 'awaitable'
mock_iter.return_value = args_iter = iter(args)
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(num_tasks, output_q.qsize())
for arg in args[:num_tasks]:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
for arg in args[num_tasks:]:
self.assertEqual(arg, next(args_iter))
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_called_once_with(output_q, args_iter)
mock_create_task.assert_called_once_with(mock_producer_coro)
self.task_pool._before_gathering.clear()
mock_join_queue.reset_mock()
mock__queue_producer.reset_mock()
mock_create_task.reset_mock()
num_tasks = 6
mock_iter.return_value = args_iter = iter(args)
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(len(args), output_q.qsize())
for arg in args:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_not_called()
mock_create_task.assert_not_called()
@patch.object(pool, 'Event')
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock,
mock_event_cls: MagicMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_flag_set = MagicMock()
mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
mock_func, stars = MagicMock(), 3
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._locked = False
with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_flag_set.assert_not_called()
self.task_pool._locked = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)])
mock_flag_set.assert_called_once_with()
@patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock):
mock_func = MagicMock()
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
args_iter, num_tasks = ([FOO], [BAR]), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
class SimpleTaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.SimpleTaskPool
task_pool: pool.SimpleTaskPool
TEST_POOL_FUNC = AsyncMock(__name__=FOO)
TEST_POOL_ARGS = (FOO, BAR)
TEST_POOL_KWARGS = {'a': 1, 'b': 2}
TEST_POOL_END_CB = MagicMock()
TEST_POOL_CANCEL_CB = MagicMock()
def get_task_pool_init_params(self) -> dict:
return super().get_task_pool_init_params() | {
'func': self.TEST_POOL_FUNC,
'args': self.TEST_POOL_ARGS,
'kwargs': self.TEST_POOL_KWARGS,
'end_callback': self.TEST_POOL_END_CB,
'cancel_callback': self.TEST_POOL_CANCEL_CB,
}
def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
self.base_class_init = self.base_class_init_patcher.start()
super().setUp()
def tearDown(self) -> None:
self.base_class_init_patcher.stop()
def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)
self.assertEqual(self.TEST_POOL_ARGS, self.task_pool._args)
self.assertEqual(self.TEST_POOL_KWARGS, self.task_pool._kwargs)
self.assertEqual(self.TEST_POOL_END_CB, self.task_pool._end_callback)
self.assertEqual(self.TEST_POOL_CANCEL_CB, self.task_pool._cancel_callback)
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
with self.assertRaises(exceptions.NotCoroutine):
pool.SimpleTaskPool(MagicMock())
def test_func_name(self):
self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name)
@patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 99
self.task_pool._func = MagicMock(return_value=BAR)
output = await self.task_pool._start_one()
self.assertEqual(expected_output, output)
self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs)
mock__start_task.assert_awaited_once_with(BAR, end_callback=self.task_pool._end_callback,
cancel_callback=self.task_pool._cancel_callback)
@patch.object(pool.SimpleTaskPool, '_start_one')
async def test_start(self, mock__start_one: AsyncMock):
mock__start_one.return_value = FOO
num = 5
output = await self.task_pool.start(num)
expected_output = num * [FOO]
self.assertListEqual(expected_output, output)
mock__start_one.assert_has_awaits(num * [call()])
@patch.object(pool.SimpleTaskPool, 'cancel')
def test_stop(self, mock_cancel: MagicMock):
num = 2
id1, id2, id3 = 5, 6, 7
self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR}
output = self.task_pool.stop(num)
expected_output = [id3, id2]
self.assertEqual(expected_output, output)
mock_cancel.assert_called_once_with(*expected_output)
mock_cancel.reset_mock()
num = 50
output = self.task_pool.stop(num)
expected_output = [id3, id2, id1]
self.assertEqual(expected_output, output)
mock_cancel.assert_called_once_with(*expected_output)
@patch.object(pool.SimpleTaskPool, 'num_running', new_callable=PropertyMock)
@patch.object(pool.SimpleTaskPool, 'stop')
def test_stop_all(self, mock_stop: MagicMock, mock_num_running: MagicMock):
mock_num_running.return_value = num = 9876
mock_stop.return_value = expected_output = 'something'
output = self.task_pool.stop_all()
self.assertEqual(expected_output, output)
mock_num_running.assert_called_once_with()
mock_stop.assert_called_once_with(num)

View File

@ -2,18 +2,18 @@
## Minimal example for `SimpleTaskPool` ## Minimal example for `SimpleTaskPool`
The minimum required setup is a "worker" coroutine function that can do something asynchronously, a main coroutine function that sets up the `SimpleTaskPool` and starts/stops the tasks as desired, eventually awaiting them all. The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
The following demo code enables full log output first for additional clarity. It is complete and should work as is. The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code ### Code
```python ```python
import logging import logging
import asyncio import asyncio
from asyncio_taskpool.pool import SimpleTaskPool from asyncio_taskpool.pool import SimpleTaskPool
logging.getLogger().setLevel(logging.NOTSET) logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler()) logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
@ -38,7 +38,7 @@ async def main() -> None:
await pool.start() # launches work task 3 await pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2 pool.stop(2) # cancels tasks 3 and 2
pool.close() # required for the last line pool.lock() # required for the last line
await pool.gather() # awaits all tasks, then flushes the pool await pool.gather() # awaits all tasks, then flushes the pool
@ -60,7 +60,7 @@ did 1
did 1 did 1
did 1 did 1
did 0 did 0
SimpleTaskPool-0 is closed! SimpleTaskPool-0 is locked!
Cancelling SimpleTaskPool-0_Task-3 ... Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3 Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3 Ended SimpleTaskPool-0_Task-3
@ -77,6 +77,155 @@ did 4
did 4 did 4
``` ```
## Advanced example ## Advanced example for `TaskPool`
... This time, we want to start tasks from _different_ coroutine functions **and** with _different_ arguments. For this we need an instance of the more generalized `TaskPool` class.
As with the simple example, we need "worker" coroutine functions that can do something asynchronously, as well as a main coroutine function that sets up the pool, starts the tasks, and eventually awaits them.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
```python
import logging
import asyncio
from asyncio_taskpool.pool import TaskPool
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
async def work(start: int, stop: int, step: int = 1) -> None:
"""Pseudo-worker function counting through a range with a second of sleep in between each iteration."""
for i in range(start, stop, step):
await asyncio.sleep(1)
print("work with", i)
async def other_work(a: int, b: int) -> None:
"""Different pseudo-worker counting through a range with half a second of sleep in between each iteration."""
for i in range(a, b):
await asyncio.sleep(0.5)
print("other_work with", i)
async def main() -> None:
# Initialize a new task pool instance and limit its size to 3 tasks.
pool = TaskPool(3)
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments).
print("Called `apply`")
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
# Let the tasks work for a bit.
await asyncio.sleep(1.5)
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
# positional arguments by using `starmap`, but have **no more than two of those** run concurrently.
# Since we set our pool size to 3, and already have two tasks working within the pool,
# only the first one of these will start immediately (and receive ID 2).
# The second one will start (with ID 3), only once there is room in the pool,
# which -- in this example -- will be the case after ID 2 ends;
# until then the `starmap` method call **will block**!
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
# **only** once there is room in the pool **and** no more than one of these last four tasks is running.
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
print("Calling `starmap`...")
await pool.starmap(other_work, args_list, num_tasks=2)
print("`starmap` returned")
# Now we lock the pool, so that we can safely await all our tasks.
pool.lock()
# Finally, we block, until all tasks have ended.
print("Called `gather`")
await pool.gather()
print("Done.")
if __name__ == '__main__':
asyncio.run(main())
```
### Output
Additional comments for the output are provided with `<---` next to the output lines.
(Keep in mind that the logger and `print` asynchronously write to `stdout`.)
```
TaskPool-0 initialized
Started TaskPool-0_Task-0
Started TaskPool-0_Task-1
Called `apply`
work with 100
work with 100
Calling `starmap`... <--- notice that this blocks as expected
Started TaskPool-0_Task-2
work with 110
work with 110
other_work with 0
other_work with 1
work with 120
work with 120
other_work with 2
other_work with 3
work with 130
work with 130
other_work with 4
other_work with 5
work with 140
work with 140
other_work with 6
other_work with 7
work with 150
work with 150
other_work with 8
Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
TaskPool-0 is locked!
Started TaskPool-0_Task-3
other_work with 9
`starmap` returned
Called `gather`
work with 160
work with 160
other_work with 10
other_work with 11
work with 170
work with 170
other_work with 12
other_work with 13
work with 180
work with 180
other_work with 14
other_work with 15
Ended TaskPool-0_Task-0
Ended TaskPool-0_Task-1 <--- even though there is room in the pool now, Task-5 will not start
Started TaskPool-0_Task-4
work with 190
work with 190
other_work with 16
other_work with 20
other_work with 17
other_work with 21
other_work with 18
other_work with 22
other_work with 19
Ended TaskPool-0_Task-3 <--- now that only Task-4 is left, Task-5 will start
Started TaskPool-0_Task-5
other_work with 23
other_work with 30
other_work with 24
other_work with 31
other_work with 25
other_work with 32
other_work with 26
other_work with 33
other_work with 27
other_work with 34
other_work with 28
other_work with 35
Ended TaskPool-0_Task-4
other_work with 29
other_work with 36
other_work with 37
other_work with 38
other_work with 39
Done.
Ended TaskPool-0_Task-5
```

View File

@ -48,12 +48,12 @@ async def main() -> None:
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue. # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join() await q.join()
# Since we don't need any "work" done anymore, we can close our control server by cancelling the task. # Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
control_server_task.cancel() control_server_task.cancel()
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks. # we can now safely cancel their tasks.
pool.stop_all() pool.stop_all()
pool.close() pool.lock()
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
# we just silently collect their exceptions along with their return values. # we just silently collect their exceptions along with their return values.