Compare commits

...

25 Commits

SHA1 Message Date
d48b20818f small improvement to the TaskPool._map args queue; added docstring to helper method 2022-02-08 12:23:16 +01:00
3c69740c8d factored out the queue setup of TaskPool._map 2022-02-08 12:09:21 +01:00
586023f722 fixed unittests; minor changes 2022-02-07 23:41:52 +01:00
16eda31648 reworked the _map method to make use of an arguments queue; added docstrings to the TaskPool methods 2022-02-07 23:17:39 +01:00
99ece436de docstrings 2022-02-07 19:26:20 +01:00
4ea815be65 refactored TaskPool._map; fixed TaskPool.apply; helper functions in separate module; new exception class 2022-02-07 17:31:43 +01:00
ac903d9be7 full test coverage for BaseTaskPool 2022-02-07 14:23:49 +01:00
e8e13406ea new interrupt flag for cancel_all to prevent _map() from starting new tasks 2022-02-06 16:32:42 +01:00
2d40f5707b docstrings for BaseTaskPool class 2022-02-06 15:00:47 +01:00
c0c9246b87 refactoring 2022-02-06 13:59:35 +01:00
ba0d5fca85 refactoring/renaming 2022-02-06 13:49:57 +01:00
b5eed608b5 refactoring; new flush method to clear memory of ended tasks 2022-02-06 13:42:34 +01:00
2f0b08edf0 implemented working pool size limit; adjusted tests and examples; small renaming 2022-02-06 13:08:39 +01:00
a68e61dfa7 small test improvement 2022-02-05 22:28:19 +01:00
9ec5359fd6 docstrings, refactorings, test adjustments 2022-02-05 22:26:02 +01:00
3a27040a54 first couple of unittests for BaseTaskPool 2022-02-05 19:34:52 +01:00
a154901bdf some convenience methods; minor renaming 2022-02-05 18:28:30 +01:00
3eae7d803f huge rework; two different task pool classes now 2022-02-05 18:02:32 +01:00
f45fef6497 usage example; pool close instead of gather 2022-02-04 19:23:09 +01:00
7020493d53 func command, minor changes 2022-02-04 18:00:23 +01:00
ed376b6f82 refactoring, additional command, improvements 2022-02-04 17:41:10 +01:00
b3b95877fb client cli and usage example 2022-02-04 16:25:44 +01:00
6a5c200ae6 working server and client 2022-02-04 16:25:09 +01:00
c9e0e2f255 first working draft 2022-02-04 11:07:02 +01:00
2e57447f5c meta adjustments 2022-02-03 20:10:26 +01:00
16 changed files with 1531 additions and 8 deletions

.gitignore vendored

@@ -8,3 +8,5 @@
/dist/
# Python cache:
__pycache__/
+# Testing:
+.coverage

README.md

@@ -4,7 +4,7 @@ Dynamically manage pools of asyncio tasks
## Usage
-...
+See [USAGE.md](usage/USAGE.md)
## Installation
@@ -12,7 +12,7 @@ Dynamically manage pools of asyncio tasks
## Dependencies
-Python Version ..., OS ...
+Python Version 3.8+, tested on Linux
## Building from source

setup.cfg

@@ -1,8 +1,8 @@
[metadata]
name = asyncio-taskpool
-version = 0.0.1
-author = Daniil
-author_email = mail@placeholder123.to
+version = 0.1.4
+author = Daniil Fajnberg
+author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks
long_description = file: README.md
long_description_content_type = text/markdown
@@ -17,9 +17,7 @@ classifiers =
package_dir =
    = src
packages = find:
-python_requires = >=3
-install_requires =
-    ...
+python_requires = >=3.8
[options.extras_require]
dev =

src/asyncio_taskpool/__init__.py Normal file

@@ -0,0 +1,2 @@
from .pool import TaskPool, SimpleTaskPool
from .server import UnixControlServer

src/asyncio_taskpool/__main__.py Normal file

@@ -0,0 +1,46 @@
import sys
from argparse import ArgumentParser
from asyncio import run
from pathlib import Path
from typing import Dict, Any
from .client import ControlClient, UnixControlClient
from .constants import PACKAGE_NAME
from .pool import TaskPool
from .server import ControlServer
CONN_TYPE = 'conn_type'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
def parse_cli() -> Dict[str, Any]:
    parser = ArgumentParser(
        prog=PACKAGE_NAME,
        description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
    )
    subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
    unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
    unix_parser.add_argument(
        SOCKET_PATH,
        type=Path,
        help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
    )
    return vars(parser.parse_args())


async def main():
    kwargs = parse_cli()
    if kwargs[CONN_TYPE] == UNIX:
        client = UnixControlClient(path=kwargs[SOCKET_PATH])
    elif kwargs[CONN_TYPE] == TCP:
        # TODO: Implement the TCP client class
        client = UnixControlClient(path=kwargs[SOCKET_PATH])
    else:
        print("Invalid connection type", file=sys.stderr)
        sys.exit(2)
    await client.start()


if __name__ == '__main__':
    run(main())
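
A minimal sketch of starting the same client programmatically instead of through this CLI module (the socket path below is illustrative, not part of the package):

```python
import asyncio

from asyncio_taskpool.client import UnixControlClient


async def main() -> None:
    # Equivalent to: python -m asyncio_taskpool unix /tmp/py_asyncio_taskpool.sock
    client = UnixControlClient(path='/tmp/py_asyncio_taskpool.sock')
    await client.start()


if __name__ == '__main__':
    asyncio.run(main())
```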

src/asyncio_taskpool/client.py Normal file

@@ -0,0 +1,63 @@
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from asyncio_taskpool import constants
from asyncio_taskpool.types import ClientConnT
class ControlClient(ABC):

    @abstractmethod
    async def open_connection(self, **kwargs) -> ClientConnT:
        raise NotImplementedError

    def __init__(self, **conn_kwargs) -> None:
        self._conn_kwargs = conn_kwargs
        self._connected: bool = False

    async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
        try:
            msg = input("> ").strip().lower()
        except EOFError:
            msg = constants.CLIENT_EXIT
        except KeyboardInterrupt:
            print()
            return
        if msg == constants.CLIENT_EXIT:
            writer.close()
            self._connected = False
            return
        try:
            writer.write(msg.encode())
            await writer.drain()
        except ConnectionError as e:
            self._connected = False
            print(e, file=sys.stderr)
            return
        print((await reader.read(constants.MSG_BYTES)).decode())

    async def start(self):
        reader, writer = await self.open_connection(**self._conn_kwargs)
        if reader is None:
            print("Failed to connect.", file=sys.stderr)
            return
        self._connected = True
        print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
        while self._connected:
            await self._interact(reader, writer)
        print("Disconnected from control server.")


class UnixControlClient(ControlClient):
    def __init__(self, **conn_kwargs) -> None:
        self._socket_path = Path(conn_kwargs.pop('path'))
        super().__init__(**conn_kwargs)

    async def open_connection(self, **kwargs) -> ClientConnT:
        try:
            return await open_unix_connection(self._socket_path, **kwargs)
        except FileNotFoundError:
            print("No socket at", self._socket_path, file=sys.stderr)
            return None, None

src/asyncio_taskpool/constants.py Normal file

@@ -0,0 +1,8 @@
PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all'
CMD_SIZE = 'size'
CMD_FUNC = 'func'
CLIENT_EXIT = 'exit'

src/asyncio_taskpool/exceptions.py Normal file

@@ -0,0 +1,30 @@
class PoolException(Exception):
    pass


class PoolIsClosed(PoolException):
    pass


class TaskEnded(PoolException):
    pass


class AlreadyCancelled(TaskEnded):
    pass


class AlreadyEnded(TaskEnded):
    pass


class InvalidTaskID(PoolException):
    pass


class PoolStillOpen(PoolException):
    pass


class NotCoroutine(PoolException):
    pass
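
All of the above derive from `PoolException`, so callers can catch pool errors broadly or narrowly. A small sketch of the intended handling (the flow is illustrative):

```python
import asyncio

from asyncio_taskpool import TaskPool, exceptions


async def main() -> None:
    pool = TaskPool()
    try:
        pool.cancel(999)  # no task with this ID exists in the pool
    except exceptions.InvalidTaskID as e:
        print('caught:', e)
    pool.close()
    await pool.gather()


if __name__ == '__main__':
    asyncio.run(main())
```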

src/asyncio_taskpool/helpers.py Normal file

@@ -0,0 +1,29 @@
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from typing import Any, Optional
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
    if not callable(function):
        return
    if kwargs is None:
        kwargs = {}
    if iscoroutinefunction(function):
        return await function(*args, **kwargs)
    return function(*args, **kwargs)


def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
    if arg_stars == 0:
        return function(arg)
    if arg_stars == 1:
        return function(*arg)
    if arg_stars == 2:
        return function(**arg)
    raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")


async def join_queue(q: Queue) -> None:
    await q.join()
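
A quick sketch of the helper semantics, assuming the module is importable as `asyncio_taskpool.helpers` (the `greet` function is made up for illustration):

```python
import asyncio

from asyncio_taskpool.helpers import execute_optional, star_function


async def greet(name: str, punct: str = '!') -> str:
    return f'Hello, {name}{punct}'


async def demo() -> None:
    # `execute_optional` awaits coroutine functions, calls plain callables,
    # and silently returns None for non-callables:
    print(await execute_optional(greet, args=('world',)))            # Hello, world!
    print(await execute_optional(len, args=('abc',)))                # 3
    print(await execute_optional(None))                              # None
    # `star_function` selects the unpacking style:
    # 0 -> func(arg), 1 -> func(*arg), 2 -> func(**arg)
    print(await star_function(greet, ('Alice', '?'), arg_stars=1))   # Hello, Alice?
    print(await star_function(greet, {'name': 'Bob'}, arg_stars=2))  # Hello, Bob!


if __name__ == '__main__':
    asyncio.run(demo())
```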

src/asyncio_taskpool/pool.py Normal file

@@ -0,0 +1,673 @@
import logging
from asyncio import gather
from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore
from asyncio.queues import Queue, QueueEmpty
from asyncio.tasks import Task, create_task
from functools import partial
from math import inf
from typing import Any, Awaitable, Dict, Iterable, Iterator, List
from . import exceptions
from .helpers import execute_optional, star_function, join_queue
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
log = logging.getLogger(__name__)
class BaseTaskPool:
    """The base class for task pools. Not intended to be used directly."""
    _pools: List['BaseTaskPool'] = []

    @classmethod
    def _add_pool(cls, pool: 'BaseTaskPool') -> int:
        """Adds a `pool` (instance of any subclass) to the general list of pools and returns its index in the list."""
        cls._pools.append(pool)
        return len(cls._pools) - 1

    def __init__(self, pool_size: int = inf, name: str = None) -> None:
        """Initializes the necessary internal attributes and adds the new pool to the general pools list."""
        self._enough_room: Semaphore = Semaphore()
        self.pool_size = pool_size
        self._open: bool = True
        self._counter: int = 0
        self._running: Dict[int, Task] = {}
        self._cancelled: Dict[int, Task] = {}
        self._ended: Dict[int, Task] = {}
        self._num_cancelled: int = 0
        self._num_ended: int = 0
        self._idx: int = self._add_pool(self)
        self._name: str = name
        self._before_gathering: List[Awaitable] = []
        self._interrupt_flag: Event = Event()
        log.debug("%s initialized", str(self))

    def __str__(self) -> str:
        return f'{self.__class__.__name__}-{self._name or self._idx}'

    @property
    def pool_size(self) -> int:
        """Returns the maximum number of concurrently running tasks currently set in the pool."""
        return self._pool_size

    @pool_size.setter
    def pool_size(self, value: int) -> None:
        """
        Sets the maximum number of concurrently running tasks in the pool.

        Args:
            value:
                A non-negative integer.
                NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.

        Raises:
            `ValueError` if `value` is less than 0.
        """
        if value < 0:
            raise ValueError("Pool size can not be less than 0")
        self._enough_room._value = value
        self._pool_size = value

    @property
    def is_open(self) -> bool:
        """Returns `True` if the pool has not been closed yet."""
        return self._open

    @property
    def num_running(self) -> int:
        """
        Returns the number of tasks in the pool that are (at that moment) still running.
        At the moment a task's `end_callback` is fired, it is no longer considered to be running.
        """
        return len(self._running)

    @property
    def num_cancelled(self) -> int:
        """
        Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
        At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
        """
        return self._num_cancelled

    @property
    def num_ended(self) -> int:
        """
        Returns the number of tasks started through the pool that have stopped running (up until that moment).
        At the moment a task's `end_callback` is fired, it is considered ended.
        When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback`
        has returned does it actually end.
        """
        return self._num_ended

    @property
    def num_finished(self) -> int:
        """
        Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
        """
        return self._num_ended - self._num_cancelled + len(self._cancelled)

    @property
    def is_full(self) -> bool:
        """
        Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
        When the pool is full, any call to start a new task within it will block.
        """
        return self._enough_room.locked()

    # TODO: Consider adding task group names
    def _task_name(self, task_id: int) -> str:
        """Returns a standardized name for a task with a specific `task_id`."""
        return f'{self}_Task-{task_id}'

    async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
        """
        Universal callback to be run upon any task in the pool being cancelled.
        Required for keeping track of running/cancelled tasks and proper logging.

        Args:
            task_id:
                The ID of the task that has been cancelled.
            custom_callback (optional):
                A callback to execute after cancellation of the task.
                It is run at the end of this function with the `task_id` as its only positional argument.
        """
        log.debug("Cancelling %s ...", self._task_name(task_id))
        self._cancelled[task_id] = self._running.pop(task_id)
        self._num_cancelled += 1
        log.debug("Cancelled %s", self._task_name(task_id))
        await execute_optional(custom_callback, args=(task_id,))

    async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
        """
        Universal callback to be run upon any task in the pool ending its work.
        Required for keeping track of running/cancelled/ended tasks and proper logging.
        Also releases room in the task pool for potentially waiting tasks.

        Args:
            task_id:
                The ID of the task that has reached its end.
            custom_callback (optional):
                A callback to execute after the task has ended.
                It is run at the end of this function with the `task_id` as its only positional argument.
        """
        try:
            self._ended[task_id] = self._running.pop(task_id)
        except KeyError:
            self._ended[task_id] = self._cancelled.pop(task_id)
        self._num_ended += 1
        self._enough_room.release()
        log.info("Ended %s", self._task_name(task_id))
        await execute_optional(custom_callback, args=(task_id,))

    async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
                            cancel_callback: CancelCallbackT = None) -> Any:
        """
        Universal wrapper around every task to be run in the pool.
        Returns/raises whatever the wrapped coroutine does.

        Args:
            awaitable:
                The actual coroutine to be run within the task pool.
            task_id:
                The ID of the newly created task.
            end_callback (optional):
                A callback to execute after the task has ended.
                It is run with the `task_id` as its only positional argument.
            cancel_callback (optional):
                A callback to execute after cancellation of the task.
                It is run with the `task_id` as its only positional argument.
        """
        log.info("Started %s", self._task_name(task_id))
        try:
            return await awaitable
        except CancelledError:
            await self._task_cancellation(task_id, custom_callback=cancel_callback)
        finally:
            await self._task_ending(task_id, custom_callback=end_callback)

    async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
                          cancel_callback: CancelCallbackT = None) -> int:
        """
        Starts a coroutine as a new task in the pool.
        This method blocks, **only if** the pool is full.
        Returns the ID of the newly created task.

        Args:
            awaitable:
                The actual coroutine to be run within the task pool.
            ignore_closed (optional):
                If `True`, even if the pool is closed, the task will still be started.
            end_callback (optional):
                A callback to execute after the task has ended.
                It is run with the task's ID as its only positional argument.
            cancel_callback (optional):
                A callback to execute after cancellation of the task.
                It is run with the task's ID as its only positional argument.

        Raises:
            `asyncio_taskpool.exceptions.NotCoroutine` if `awaitable` is not a coroutine.
            `asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
        """
        if not iscoroutine(awaitable):
            raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
        if not (self.is_open or ignore_closed):
            raise exceptions.PoolIsClosed("Cannot start new tasks")
        await self._enough_room.acquire()
        task_id = self._counter
        self._counter += 1
        try:
            self._running[task_id] = create_task(
                self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
                name=self._task_name(task_id)
            )
        except Exception as e:
            self._enough_room.release()
            raise e
        return task_id

    def _get_running_task(self, task_id: int) -> Task:
        """
        Gets a running task by its task ID.

        Args:
            task_id: The ID of a task still running within the pool.

        Raises:
            `asyncio_taskpool.exceptions.AlreadyCancelled` if the task with `task_id` has been (recently) cancelled.
            `asyncio_taskpool.exceptions.AlreadyEnded` if the task with `task_id` has ended (recently).
            `asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool.
        """
        try:
            return self._running[task_id]
        except KeyError:
            if self._cancelled.get(task_id):
                raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
            if self._ended.get(task_id):
                raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
            raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")

    def cancel(self, *task_ids: int, msg: str = None) -> None:
        """
        Cancels the tasks with the specified IDs.

        Each task ID must belong to a task still running within the pool. Otherwise one of the following exceptions
        will be raised:
        - `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled.
        - `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently).
        - `InvalidTaskID` if any of the `task_ids` is not known to the pool.
        Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten.

        Args:
            task_ids:
                Arbitrary number of integers. Each must be an ID of a task still running within the pool.
            msg (optional):
                Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
        """
        tasks = [self._get_running_task(task_id) for task_id in task_ids]
        for task in tasks:
            task.cancel(msg=msg)

    def cancel_all(self, msg: str = None) -> None:
        """
        Cancels all tasks still running within the pool.

        Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
        This can happen if, for example, the `TaskPool.map` method was called with `num_tasks` set to a number
        smaller than the number of arguments from `args_iter`.
        In this case, those already running will be cancelled, while the following will **never even start**.

        Args:
            msg (optional):
                Passed to the `Task.cancel()` method of every task still running within the pool.
        """
        log.warning("%s cancelling all tasks!", str(self))
        self._interrupt_flag.set()
        for task in self._running.values():
            task.cancel(msg=msg)

    async def flush(self, return_exceptions: bool = False):
        """
        Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the
        tasks.
        This method blocks, **only if** any of the tasks block while catching an `asyncio.CancelledError` or if any
        of the callbacks registered for the tasks block.

        Args:
            return_exceptions (optional): Passed directly into `gather`.
        """
        results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
        self._ended.clear()
        self._cancelled.clear()
        if self._interrupt_flag.is_set():
            self._interrupt_flag.clear()
        return results

    def close(self) -> None:
        """Disallows any more tasks to be started in the pool."""
        self._open = False
        log.info("%s is closed!", str(self))

    async def gather(self, return_exceptions: bool = False):
        """
        Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
        The `close()` method must have been called prior to this.

        Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
        This can happen if, for example, the `TaskPool.map` method was called with `num_tasks` set to a number
        smaller than the number of arguments from `args_iter`.
        In this case, calling `cancel_all()` beforehand will prevent those tasks from starting and potentially
        blocking this method. Otherwise, it will wait until they have all started.

        This method may also block if any task blocks while catching an `asyncio.CancelledError` or if any of the
        callbacks registered for a task blocks.

        Args:
            return_exceptions (optional): Passed directly into `gather`.

        Raises:
            `asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
        """
        if self._open:
            raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
        await gather(*self._before_gathering)
        results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
                               return_exceptions=return_exceptions)
        self._ended.clear()
        self._cancelled.clear()
        self._running.clear()
        self._before_gathering.clear()
        if self._interrupt_flag.is_set():
            self._interrupt_flag.clear()
        return results


class TaskPool(BaseTaskPool):
    """
    General task pool class.
    Attempts to somewhat emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.

    A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
    Tasks in the pool can all belong to the same coroutine function,
    but they can also come from any number of different and unrelated coroutine functions.

    As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.)
    Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment.
    Adding tasks blocks **only if** the pool is full at that moment.
    """

    async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
                         end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
        """
        Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
        This method blocks, **only if** the pool is full.

        Args:
            func:
                The coroutine function to be run as a task within the task pool.
            args (optional):
                The positional arguments to pass into the function call.
            kwargs (optional):
                The keyword-arguments to pass into the function call.
            end_callback (optional):
                A callback to execute after the task has ended.
                It is run with the task's ID as its only positional argument.
            cancel_callback (optional):
                A callback to execute after cancellation of the task.
                It is run with the task's ID as its only positional argument.

        Returns:
            The newly spawned task's ID within the pool.
        """
        if kwargs is None:
            kwargs = {}
        return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)

    async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
                    end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]:
        """
        Creates an arbitrary number of coroutines with the supplied arguments and runs them as new tasks in the pool.
        Each coroutine looks like `func(*args, **kwargs)`.
        This method blocks, **only if** there is not enough room in the pool for the desired number of new tasks.

        Args:
            func:
                The coroutine function to use for spawning the new tasks within the task pool.
            args (optional):
                The positional arguments to pass into each function call.
            kwargs (optional):
                The keyword-arguments to pass into each function call.
            num (optional):
                The number of tasks to spawn with the specified parameters.
            end_callback (optional):
                A callback to execute after a task has ended.
                It is run with the task's ID as its only positional argument.
            cancel_callback (optional):
                A callback to execute after cancellation of a task.
                It is run with the task's ID as its only positional argument.

        Returns:
            The newly spawned tasks' IDs within the pool as a list of integers.

        Raises:
            `NotCoroutine` if `func` is not a coroutine function.
            `PoolIsClosed` if the pool has been closed already.
        """
        ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
        # TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
        assert isinstance(ids, list)
        return ids

    async def _queue_producer(self, q: Queue, args_iter: Iterator[Any]) -> None:
        """
        Keeps the arguments queue from `_map()` full as long as the iterator has elements.
        If the `_interrupt_flag` gets set, the loop ends prematurely.

        Args:
            q:
                The queue of function arguments to consume for starting the next task.
            args_iter:
                The iterator of function arguments to put into the queue.
        """
        for arg in args_iter:
            if self._interrupt_flag.is_set():
                break
            await q.put(arg)  # This blocks as long as the queue is full.

    async def _queue_consumer(self, q: Queue, func: CoroutineFunc, arg_stars: int = 0,
                              end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Wrapper around `_start_task()` that takes the next element from the arguments queue set up in `_map()`.
        Partially constructs the `_queue_callback` function with the same arguments.

        Args:
            q:
                The queue of function arguments to consume for starting the next task.
            func:
                The coroutine function to use for spawning the tasks within the task pool.
            arg_stars (optional):
                Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
            end_callback (optional):
                The actual callback specified to execute after the task (and the next one) has ended.
                It is run with the task's ID as its only positional argument.
            cancel_callback (optional):
                The callback that was specified to execute after cancellation of the task (and the next one).
                It is run with the task's ID as its only positional argument.
        """
        try:
            arg = q.get_nowait()
        except QueueEmpty:
            return
        try:
            await self._start_task(
                star_function(func, arg, arg_stars=arg_stars),
                ignore_closed=True,
                end_callback=partial(TaskPool._queue_callback, self, q=q, func=func, arg_stars=arg_stars,
                                     end_callback=end_callback, cancel_callback=cancel_callback),
                cancel_callback=cancel_callback
            )
        finally:
            q.task_done()

    async def _queue_callback(self, task_id: int, q: Queue, func: CoroutineFunc, arg_stars: int = 0,
                              end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Wrapper around an end callback function passed into the `_map()` method.
        Triggers the next `_queue_consumer` with the same arguments.

        Args:
            task_id:
                The ID of the ending task.
            q:
                The queue of function arguments to consume for starting the next task.
            func:
                The coroutine function to use for spawning the tasks within the task pool.
            arg_stars (optional):
                Whether or not to unpack an element from `q` using stars; must be 0, 1, or 2.
            end_callback (optional):
                The actual callback specified to execute after the task (and the next one) has ended.
                It is run with the `task_id` as its only positional argument.
            cancel_callback (optional):
                The callback that was specified to execute after cancellation of the task (and the next one).
                It is run with the `task_id` as its only positional argument.
        """
        await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
        await execute_optional(end_callback, args=(task_id,))

    def _fill_args_queue(self, q: Queue, args_iter: ArgsT, num_tasks: int) -> None:
        """
        Helper function for `_map()`.
        Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` of them to the arguments
        queue `q`.
        If the iterable contains fewer than `num_tasks` elements, nothing else happens.
        Otherwise, the `_queue_producer` is started with the arguments queue and an iterator of the remaining
        arguments.

        Args:
            q:
                The (empty) new `asyncio.Queue` to hold the function arguments passed as `args_iter`.
            args_iter:
                The iterable of function arguments passed into `_map()` to use for creating the new tasks.
            num_tasks:
                The maximum number of the new tasks to run concurrently that was passed into `_map()`.
        """
        args_iter = iter(args_iter)
        try:
            # Here we guarantee that the queue will contain as many arguments as needed for starting the first batch
            # of tasks, which will be at most `num_tasks` (meaning the queue will be full).
            for i in range(num_tasks):
                q.put_nowait(next(args_iter))
        except StopIteration:
            # If we get here, this means that the number of elements in the arguments iterator was less than the
            # specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue.
            # The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
            return
        # There may be more elements in the arguments iterator, so we need the `_queue_producer`.
        # It will have exclusive access to the `args_iter` from now on.
        # If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting
        # the next item in it.
        create_task(self._queue_producer(q, args_iter))

    async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
                   end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
        TODO: If task groups are implemented, consider adding all tasks from one call of this method to the same
              group and referring to "group size" rather than chunk/batch size.
        Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being an element from the
        iterable.
        This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.
        It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
        The queue's `join()` method is added to the pool's `_before_gathering` list.

        Args:
            func:
                The coroutine function to use for spawning the new tasks within the task pool.
            args_iter:
                The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
            arg_stars (optional):
                Whether or not to unpack an element from `args_iter` using stars; must be 0, 1, or 2.
            num_tasks (optional):
                The maximum number of the new tasks to run concurrently.
            end_callback (optional):
                A callback to execute after a task has ended.
                It is run with the task's ID as its only positional argument.
            cancel_callback (optional):
                A callback to execute after cancellation of a task.
                It is run with the task's ID as its only positional argument.

        Raises:
            `asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed.
        """
        if not self.is_open:
            raise exceptions.PoolIsClosed("Cannot start new tasks")
        args_queue = Queue(maxsize=num_tasks)
        self._before_gathering.append(join_queue(args_queue))
        self._fill_args_queue(args_queue, args_iter, num_tasks)
        for _ in range(args_queue.qsize()):
            # This is where blocking can occur, if the pool is full.
            await self._queue_consumer(args_queue, func,
                                       arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)

    async def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_tasks: int = 1,
                  end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.

        Creates coroutines with arguments from a supplied iterable and runs them as new tasks in the pool in batches.
        Each coroutine looks like `func(arg)`, `arg` being an element from the iterable.

        Once the first batch of tasks has started to run, this method returns.
        As soon as one of them finishes, it triggers the start of a new task (assuming there is room in the pool),
        which consumes the next element from the arguments iterable.
        Provided the pool's size never imposes a limit, this ensures that roughly the desired number of tasks from
        this call is running concurrently within the pool at any moment.

        This method blocks, **only if** there is not enough room in the pool for the first batch of new tasks.

        Args:
            func:
                The coroutine function to use for spawning the new tasks within the task pool.
            arg_iter:
                The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
            num_tasks (optional):
                The maximum number of the new tasks to run concurrently.
            end_callback (optional):
                A callback to execute after a task has ended.
                It is run with the task's ID as its only positional argument.
            cancel_callback (optional):
                A callback to execute after cancellation of a task.
                It is run with the task's ID as its only positional argument.

        Raises:
            `PoolIsClosed` if the pool has been closed.
            `NotCoroutine` if `func` is not a coroutine function.
        """
        await self._map(func, arg_iter, arg_stars=0, num_tasks=num_tasks,
                        end_callback=end_callback, cancel_callback=cancel_callback)

    async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
                      end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Like `map()` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
        as positional arguments to the function.
        Each coroutine then looks like `func(*arg)`, `arg` being an element from `args_iter`.
        """
        await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
                        end_callback=end_callback, cancel_callback=cancel_callback)

    async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
                            end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Like `map()` except that the elements of `kwargs_iter` are expected to be mappings to be unpacked as
        keyword arguments to the function.
        Each coroutine then looks like `func(**arg)`, `arg` being an element from `kwargs_iter`.
        """
        await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
                        end_callback=end_callback, cancel_callback=cancel_callback)


class SimpleTaskPool(BaseTaskPool):
    def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
                 end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
                 name: str = None) -> None:
        if not iscoroutinefunction(func):
            raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
        self._func: CoroutineFunc = func
        self._args: ArgsT = args
        self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
        self._end_callback: EndCallbackT = end_callback
        self._cancel_callback: CancelCallbackT = cancel_callback
        super().__init__(name=name)

    @property
    def func_name(self) -> str:
        return self._func.__name__

    @property
    def size(self) -> int:
        return self.num_running

    async def _start_one(self) -> int:
        return await self._start_task(self._func(*self._args, **self._kwargs),
                                      end_callback=self._end_callback, cancel_callback=self._cancel_callback)

    async def start(self, num: int = 1) -> List[int]:
        return [await self._start_one() for _ in range(num)]

    def stop(self, num: int = 1) -> List[int]:
        num = min(num, self.size)
        ids = []
        for i, task_id in enumerate(reversed(self._running)):
            if i >= num:
                break
            ids.append(task_id)
        self.cancel(*ids)
        return ids

    def stop_all(self) -> List[int]:
        return self.stop(self.size)
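
To make the `TaskPool` interface concrete, here is a minimal sketch of `map()` (the worker coroutine is made up; compare the `SimpleTaskPool` example in usage/USAGE.md below):

```python
import asyncio

from asyncio_taskpool import TaskPool


async def work(n: int) -> None:
    await asyncio.sleep(0.1)
    print('processed', n)


async def main() -> None:
    pool = TaskPool()
    # At most 5 `work` tasks run at any moment; whenever one ends, the next
    # argument is consumed from the iterable and a new task is started.
    await pool.map(work, range(20), num_tasks=5)
    pool.close()         # `gather()` requires the pool to be closed first
    await pool.gather()  # waits for all tasks, including the "queued" ones


if __name__ == '__main__':
    asyncio.run(main())
```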

src/asyncio_taskpool/server.py Normal file

@@ -0,0 +1,130 @@
import logging
from abc import ABC, abstractmethod
from asyncio import AbstractServer
from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Tuple, Union, Optional
from . import constants
from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient
log = logging.getLogger(__name__)
def tasks_str(num: int) -> str:
    return "tasks" if num != 1 else "task"


def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
    cmd = msg.strip().split(' ', 1)
    if len(cmd) > 1:
        try:
            return cmd[0], int(cmd[1])
        except ValueError:
            return None, None
    return cmd[0], None


class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
    client_class = ControlClient

    @abstractmethod
    async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
        raise NotImplementedError

    @abstractmethod
    def final_callback(self) -> None:
        raise NotImplementedError

    def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
        self._pool: SimpleTaskPool = pool
        self._server_kwargs = server_kwargs
        self._server: Optional[AbstractServer] = None

    async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
        if num is None:
            num = 1
        log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
        writer.write(str(await self._pool.start(num)).encode())

    def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
        if num is None:
            num = 1
        log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
        # the requested number may be greater than the total number of running tasks
        writer.write(str(self._pool.stop(num)).encode())

    def _stop_all_tasks(self, writer: StreamWriter) -> None:
        log.debug("%s requests stopping all tasks", self.client_class.__name__)
        writer.write(str(self._pool.stop_all()).encode())

    def _pool_size(self, writer: StreamWriter) -> None:
        log.debug("%s requests pool size", self.client_class.__name__)
        writer.write(str(self._pool.size).encode())

    def _pool_func(self, writer: StreamWriter) -> None:
        log.debug("%s requests pool function", self.client_class.__name__)
        writer.write(self._pool.func_name.encode())

    async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
        while self._server.is_serving():
            msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
            if not msg:
                log.debug("%s disconnected", self.client_class.__name__)
                break
            cmd, arg = get_cmd_arg(msg)
            if cmd == constants.CMD_START:
                await self._start_tasks(writer, arg)
            elif cmd == constants.CMD_STOP:
                self._stop_tasks(writer, arg)
            elif cmd == constants.CMD_STOP_ALL:
                self._stop_all_tasks(writer)
            elif cmd == constants.CMD_SIZE:
                self._pool_size(writer)
            elif cmd == constants.CMD_FUNC:
                self._pool_func(writer)
            else:
                log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
                writer.write(b"Invalid command!")
            await writer.drain()

    async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
        log.debug("%s connected", self.client_class.__name__)
        writer.write(str(self._pool).encode())
        await writer.drain()
        await self._listen(reader, writer)

    async def _serve_forever(self) -> None:
        try:
            async with self._server:
                await self._server.serve_forever()
        except CancelledError:
            log.debug("%s stopped", self.__class__.__name__)
        finally:
            self.final_callback()

    async def serve_forever(self) -> Task:
        log.debug("Starting %s...", self.__class__.__name__)
        self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
        return create_task(self._serve_forever())


class UnixControlServer(ControlServer):
    client_class = UnixControlClient

    def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
        self._socket_path = Path(server_kwargs.pop('path'))
        super().__init__(pool, **server_kwargs)

    async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
        srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
        log.debug("Opened socket '%s'", str(self._socket_path))
        return srv

    def final_callback(self) -> None:
        self._socket_path.unlink()
        log.debug("Removed socket '%s'", str(self._socket_path))

src/asyncio_taskpool/types.py Normal file

@@ -0,0 +1,16 @@
from asyncio.streams import StreamReader, StreamWriter
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T')
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[..., Union[Awaitable[T], T]]
CoroutineFunc = Callable[..., Awaitable[Any]]
EndCallbackT = Callable
CancelCallbackT = Callable
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

tests/test_pool.py Normal file

@@ -0,0 +1,379 @@
import asyncio
from asyncio.exceptions import CancelledError
from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from asyncio_taskpool import pool, exceptions
EMPTY_LIST, EMPTY_DICT = [], {}
FOO, BAR = 'foo', 'bar'
class TestException(Exception):
    pass


class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
    log_lvl: int

    @classmethod
    def setUpClass(cls) -> None:
        cls.log_lvl = pool.log.level
        pool.log.setLevel(999)

    @classmethod
    def tearDownClass(cls) -> None:
        pool.log.setLevel(cls.log_lvl)

    def setUp(self) -> None:
        self._pools = getattr(pool.BaseTaskPool, '_pools')
        # These three methods are called during initialization, so we mock them by default during setup
        self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
        self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
        self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
        self.mock__add_pool = self._add_pool_patcher.start()
        self.mock_pool_size = self.pool_size_patcher.start()
        self.mock___str__ = self.__str___patcher.start()
        self.mock__add_pool.return_value = self.mock_idx = 123
        self.mock___str__.return_value = self.mock_str = 'foobar'
        # Test pool parameters:
        self.test_pool_size, self.test_pool_name = 420, 'test123'
        self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)

    def tearDown(self) -> None:
        setattr(pool.TaskPool, '_pools', self._pools)
        self._add_pool_patcher.stop()
        self.pool_size_patcher.stop()
        self.__str___patcher.stop()

    def test__add_pool(self):
        self.assertListEqual(EMPTY_LIST, self._pools)
        self._add_pool_patcher.stop()
        output = pool.TaskPool._add_pool(self.task_pool)
        self.assertEqual(0, output)
        self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))

    def test_init(self):
        self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
        self.assertTrue(self.task_pool._open)
        self.assertEqual(0, self.task_pool._counter)
        self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
        self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
        self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
        self.assertEqual(0, self.task_pool._num_cancelled)
        self.assertEqual(0, self.task_pool._num_ended)
        self.assertEqual(self.mock_idx, self.task_pool._idx)
        self.assertEqual(self.test_pool_name, self.task_pool._name)
        self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
        self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
        self.assertFalse(self.task_pool._interrupt_flag.is_set())
        self.mock__add_pool.assert_called_once_with(self.task_pool)
        self.mock_pool_size.assert_called_once_with(self.test_pool_size)
        self.mock___str__.assert_called_once_with()

    def test___str__(self):
        self.__str___patcher.stop()
        expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
        self.assertEqual(expected_str, str(self.task_pool))
        setattr(self.task_pool, '_name', None)
        expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
        self.assertEqual(expected_str, str(self.task_pool))

    def test_pool_size(self):
        self.pool_size_patcher.stop()
        self.task_pool._pool_size = self.test_pool_size
        self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
        with self.assertRaises(ValueError):
            self.task_pool.pool_size = -1
        self.task_pool.pool_size = new_size = 69
        self.assertEqual(new_size, self.task_pool._pool_size)

    def test_is_open(self):
        self.task_pool._open = FOO
        self.assertEqual(FOO, self.task_pool.is_open)

    def test_num_running(self):
        self.task_pool._running = ['foo', 'bar', 'baz']
        self.assertEqual(3, self.task_pool.num_running)

    def test_num_cancelled(self):
        self.task_pool._num_cancelled = 3
        self.assertEqual(3, self.task_pool.num_cancelled)

    def test_num_ended(self):
        self.task_pool._num_ended = 3
        self.assertEqual(3, self.task_pool.num_ended)

    def test_num_finished(self):
        self.task_pool._num_cancelled = cancelled = 69
        self.task_pool._num_ended = ended = 420
        self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
        self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)

    def test_is_full(self):
        self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)

    def test__task_name(self):
        i = 123
        self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))

    @patch.object(pool, 'execute_optional')
    @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
    async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
        task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
        self.task_pool._num_cancelled = cancelled = 3
        self.task_pool._running[task_id] = mock_task
        self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
        self.assertNotIn(task_id, self.task_pool._running)
        self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
        self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
        mock__task_name.assert_called_with(task_id)
        mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))

    @patch.object(pool, 'execute_optional')
    @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
    async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
        task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
        self.task_pool._num_ended = ended = 3
        self.task_pool._enough_room._value = room = 123
        # End running task:
        self.task_pool._running[task_id] = mock_task
        self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
        self.assertNotIn(task_id, self.task_pool._running)
        self.assertEqual(mock_task, self.task_pool._ended[task_id])
        self.assertEqual(ended + 1, self.task_pool._num_ended)
        self.assertEqual(room + 1, self.task_pool._enough_room._value)
        mock__task_name.assert_called_with(task_id)
        mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
        mock__task_name.reset_mock()
        mock_execute_optional.reset_mock()
        # End cancelled task:
        self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
        self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
        self.assertNotIn(task_id, self.task_pool._cancelled)
        self.assertEqual(mock_task, self.task_pool._ended[task_id])
        self.assertEqual(ended + 2, self.task_pool._num_ended)
        self.assertEqual(room + 2, self.task_pool._enough_room._value)
        mock__task_name.assert_called_with(task_id)
        mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))

    @patch.object(pool.BaseTaskPool, '_task_ending')
    @patch.object(pool.BaseTaskPool, '_task_cancellation')
    @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
    async def test__task_wrapper(self, mock__task_name: MagicMock,
                                 mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock):
        task_id = 42
        mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock()
        mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError)
        # Cancelled during execution:
        mock_awaitable = mock_coroutine_func()
        output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
                                                    end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
        self.assertIsNone(output)
        mock_coroutine_func.assert_awaited_once()
        mock__task_name.assert_called_with(task_id)
        mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb)
        mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
        mock_coroutine_func.reset_mock(side_effect=True)
        mock__task_name.reset_mock()
        mock__task_cancellation.reset_mock()
        mock__task_ending.reset_mock()
        # Not cancelled:
        mock_awaitable = mock_coroutine_func()
        output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
                                                    end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
        self.assertEqual(FOO, output)
        mock_coroutine_func.assert_awaited_once()
        mock__task_name.assert_called_with(task_id)
        mock__task_cancellation.assert_not_awaited()
        mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)

    @patch.object(pool, 'create_task')
    @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
    @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
    @patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
    async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
                               mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
        def reset_mocks() -> None:
            mock_is_open.reset_mock()
            mock__task_name.reset_mock()
            mock__task_wrapper.reset_mock()
            mock_create_task.reset_mock()

        mock_create_task.return_value = mock_task = MagicMock()
        mock__task_wrapper.return_value = mock_wrapped = MagicMock()
        mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
        self.task_pool._counter = count = 123
        self.task_pool._enough_room._value = room = 123

        with self.assertRaises(exceptions.NotCoroutine):
            await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
        self.assertEqual(count, self.task_pool._counter)
        self.assertNotIn(count, self.task_pool._running)
        self.assertEqual(room, self.task_pool._enough_room._value)
        mock_is_open.assert_not_called()
        mock__task_name.assert_not_called()
        mock__task_wrapper.assert_not_called()
        mock_create_task.assert_not_called()
        reset_mocks()

        mock_is_open.return_value = ignore_closed = False
        mock_awaitable = mock_coroutine()
        with self.assertRaises(exceptions.PoolIsClosed):
            await self.task_pool._start_task(mock_awaitable, ignore_closed,
                                             end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
        await mock_awaitable
        self.assertEqual(count, self.task_pool._counter)
        self.assertNotIn(count, self.task_pool._running)
        self.assertEqual(room, self.task_pool._enough_room._value)
        mock_is_open.assert_called_once_with()
        mock__task_name.assert_not_called()
        mock__task_wrapper.assert_not_called()
        mock_create_task.assert_not_called()
        reset_mocks()

        ignore_closed = True
        mock_awaitable = mock_coroutine()
        output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
                                                  end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
        await mock_awaitable
        self.assertEqual(count, output)
        self.assertEqual(count + 1, self.task_pool._counter)
        self.assertEqual(mock_task, self.task_pool._running[count])
        self.assertEqual(room - 1, self.task_pool._enough_room._value)
        mock_is_open.assert_called_once_with()
        mock__task_name.assert_called_once_with(count)
        mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
        mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
        reset_mocks()

        self.task_pool._counter = count
        self.task_pool._enough_room._value = room
        del self.task_pool._running[count]
        mock_awaitable = mock_coroutine()
        mock_create_task.side_effect = test_exception = TestException()
        with self.assertRaises(TestException) as e:
            await self.task_pool._start_task(mock_awaitable, ignore_closed,
                                             end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
            self.assertEqual(test_exception, e)
        await mock_awaitable
        self.assertEqual(count + 1, self.task_pool._counter)
        self.assertNotIn(count, self.task_pool._running)
        self.assertEqual(room, self.task_pool._enough_room._value)
        mock_is_open.assert_called_once_with()
        mock__task_name.assert_called_once_with(count)
        mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
        mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)

    @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
    def test__get_running_task(self, mock__task_name: MagicMock):
        task_id, mock_task = 555, MagicMock()
        self.task_pool._running[task_id] = mock_task
        output = self.task_pool._get_running_task(task_id)
        self.assertEqual(mock_task, output)

        self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
        with self.assertRaises(exceptions.AlreadyCancelled):
            self.task_pool._get_running_task(task_id)
        mock__task_name.assert_called_once_with(task_id)
        mock__task_name.reset_mock()

        self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
        with self.assertRaises(exceptions.TaskEnded):
            self.task_pool._get_running_task(task_id)
        mock__task_name.assert_called_once_with(task_id)
        mock__task_name.reset_mock()

        del self.task_pool._ended[task_id]
        with self.assertRaises(exceptions.InvalidTaskID):
            self.task_pool._get_running_task(task_id)
        mock__task_name.assert_not_called()

    @patch.object(pool.BaseTaskPool, '_get_running_task')
    def test_cancel(self, mock__get_running_task: MagicMock):
        task_id1, task_id2, task_id3 = 1, 4, 9
        mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
        self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
        mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
        mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])

    def test_cancel_all(self):
        mock_task1, mock_task2 = MagicMock(), MagicMock()
        self.task_pool._running = {1: mock_task1, 2: mock_task2}
        assert not self.task_pool._interrupt_flag.is_set()
        self.assertIsNone(self.task_pool.cancel_all(FOO))
        self.assertTrue(self.task_pool._interrupt_flag.is_set())
        mock_task1.cancel.assert_called_once_with(msg=FOO)
        mock_task2.cancel.assert_called_once_with(msg=FOO)

    async def test_flush(self):
        test_exception = TestException()
        mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
        self.task_pool._ended = {123: mock_ended_func()}
        self.task_pool._cancelled = {456: mock_cancelled_func()}
        self.task_pool._interrupt_flag.set()
        output = await self.task_pool.flush(return_exceptions=True)
        self.assertListEqual([FOO, test_exception], output)
        self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
        self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
        self.assertFalse(self.task_pool._interrupt_flag.is_set())

        self.task_pool._ended = {123: mock_ended_func()}
        self.task_pool._cancelled = {456: mock_cancelled_func()}
        output = await self.task_pool.flush(return_exceptions=True)
        self.assertListEqual([FOO, test_exception], output)
        self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
        self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)

    def test_close(self):
        assert self.task_pool._open
        self.task_pool.close()
        self.assertFalse(self.task_pool._open)

    async def test_gather(self):
        test_exception = TestException()
        mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
        mock_running_func = AsyncMock(return_value=BAR)
        mock_queue_join = AsyncMock()
        self.task_pool._before_gathering = before_gather = [mock_queue_join()]
        self.task_pool._ended = ended = {123: mock_ended_func()}
        self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
        self.task_pool._running = running = {789: mock_running_func()}
        self.task_pool._interrupt_flag.set()

        assert self.task_pool._open
        with self.assertRaises(exceptions.PoolStillOpen):
            await self.task_pool.gather()
        self.assertDictEqual(self.task_pool._ended, ended)
        self.assertDictEqual(self.task_pool._cancelled, cancelled)
        self.assertDictEqual(self.task_pool._running, running)
        self.assertListEqual(self.task_pool._before_gathering, before_gather)
        self.assertTrue(self.task_pool._interrupt_flag.is_set())

        self.task_pool._open = False

        def check_assertions(output) -> None:
            self.assertListEqual([FOO, test_exception, BAR], output)
            self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
            self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
            self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
            self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
            self.assertFalse(self.task_pool._interrupt_flag.is_set())

        check_assertions(await self.task_pool.gather(return_exceptions=True))

        self.task_pool._before_gathering = [mock_queue_join()]
        self.task_pool._ended = {123: mock_ended_func()}
        self.task_pool._cancelled = {456: mock_cancelled_func()}
        self.task_pool._running = {789: mock_running_func()}
        check_assertions(await self.task_pool.gather(return_exceptions=True))

usage/USAGE.md Normal file

@@ -0,0 +1,82 @@
# Using `asyncio-taskpool`
## Minimal example for `SimpleTaskPool`
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
```python
import logging
import asyncio
from asyncio_taskpool.pool import SimpleTaskPool
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
async def work(n: int) -> None:
    """
    Pseudo-worker function.
    Counts up to an integer with a second of sleep before each iteration.
    In a real-world use case, a worker function should probably have access
    to some synchronisation primitive or shared resource to distribute work
    between an arbitrary number of workers.
    """
    for i in range(n):
        await asyncio.sleep(1)
        print("did", i)


async def main() -> None:
    pool = SimpleTaskPool(work, (5,))  # initializes the pool; no work is being done yet
    await pool.start(3)  # launches work tasks 0, 1, and 2
    await asyncio.sleep(1.5)  # lets the tasks work for a bit
    await pool.start()  # launches work task 3
    await asyncio.sleep(1.5)  # lets the tasks work for a bit
    pool.stop(2)  # cancels tasks 3 and 2
    pool.close()  # required for the last line
    await pool.gather()  # awaits all tasks, then flushes the pool


if __name__ == '__main__':
    asyncio.run(main())
```
### Output
```
SimpleTaskPool-0 initialized
Started SimpleTaskPool-0_Task-0
Started SimpleTaskPool-0_Task-1
Started SimpleTaskPool-0_Task-2
did 0
did 0
did 0
Started SimpleTaskPool-0_Task-3
did 1
did 1
did 1
did 0
SimpleTaskPool-0 is closed!
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
Cancelling SimpleTaskPool-0_Task-2 ...
Cancelled SimpleTaskPool-0_Task-2
Ended SimpleTaskPool-0_Task-2
did 2
did 2
did 3
did 3
Ended SimpleTaskPool-0_Task-0
Ended SimpleTaskPool-0_Task-1
did 4
did 4
```
## Advanced example
...

usage/example_server.py Normal file

@@ -0,0 +1,65 @@
import asyncio
import logging
from asyncio_taskpool import SimpleTaskPool, UnixControlServer
from asyncio_taskpool.constants import PACKAGE_NAME
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
async def work(item: int) -> None:
    """The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
    await asyncio.sleep(1)
    print("worked on", item)


async def worker(q: asyncio.Queue) -> None:
    """Simulates doing asynchronous work that takes a little bit of time to finish."""
    # We only want the worker to stop when its task is cancelled; therefore we start an infinite loop.
    while True:
        # We want to block here until we can get the next item from the queue.
        item = await q.get()
        # Since we want a nice cleanup upon cancellation, we put the "work" to be done in a `try:` block.
        try:
            await work(item)
        except asyncio.CancelledError:
            # If the task gets cancelled before our current "work" item is finished, we put it back into the queue
            # because a worker must assume that some other worker can and will eventually finish the work on that
            # item.
            q.put_nowait(item)
            # This takes us out of the loop. To enable cleanup we must re-raise the exception.
            raise
        finally:
            # Since putting an item into the queue (even if it has just been taken out) increments the internal
            # `._unfinished_tasks` counter in the queue, we must ensure that it is decremented before we end the
            # iteration or leave the loop. Otherwise, the queue's `.join()` will block indefinitely.
            q.task_done()


async def main() -> None:
    # First, we set up a queue of items that our workers can "work" on.
    q = asyncio.Queue()
    # We just put some integers into our queue, since all our workers actually do is print an item and sleep for a
    # bit.
    for item in range(100):
        q.put_nowait(item)
    pool = SimpleTaskPool(worker, (q,))  # initializes the pool
    await pool.start(3)  # launches three worker tasks
    control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
    # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
    await q.join()
    # Since we don't need any "work" done anymore, we can close our control server by cancelling the task.
    control_server_task.cancel()
    # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
    # we can now safely cancel their tasks.
    pool.stop_all()
    pool.close()
    # Finally, we allow all tasks to do their cleanup, if they need to do any, upon being cancelled.
    # We block until they all return or raise an exception, but since we are not interested in any of their
    # exceptions, we just silently collect their exceptions along with their return values.
    await pool.gather(return_exceptions=True)
    await control_server_task


if __name__ == '__main__':
    asyncio.run(main())