Compare commits

..

9 Commits

9 changed files with 576 additions and 169 deletions

2
.gitignore vendored
View File

@ -8,3 +8,5 @@
/dist/ /dist/
# Python cache: # Python cache:
__pycache__/ __pycache__/
# Testing:
.coverage

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.0.1 version = 0.0.2
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -14,7 +14,7 @@ class AlreadyCancelled(TaskEnded):
pass pass
class AlreadyFinished(TaskEnded): class AlreadyEnded(TaskEnded):
pass pass

View File

@ -2,164 +2,348 @@ import logging
from asyncio import gather from asyncio import gather
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Event from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task from asyncio.tasks import Task, create_task
from math import inf from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from . import exceptions from . import exceptions
from .types import ArgsT, KwArgsT, CoroutineFunc, FinalCallbackT, CancelCallbackT from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class BaseTaskPool: class BaseTaskPool:
"""The base class for task pools. Not intended to be used directly."""
_pools: List['BaseTaskPool'] = [] _pools: List['BaseTaskPool'] = []
@classmethod @classmethod
def _add_pool(cls, pool: 'BaseTaskPool') -> int: def _add_pool(cls, pool: 'BaseTaskPool') -> int:
"""Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list."""
cls._pools.append(pool) cls._pools.append(pool)
return len(cls._pools) - 1 return len(cls._pools) - 1
# TODO: Make use of `max_running` def __init__(self, pool_size: int = inf, name: str = None) -> None:
def __init__(self, max_running: int = inf, name: str = None) -> None: """Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self._max_running: int = max_running self._enough_room: Semaphore = Semaphore()
self.pool_size = pool_size
self._open: bool = True self._open: bool = True
self._counter: int = 0 self._counter: int = 0
self._running: Dict[int, Task] = {} self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {} self._cancelled: Dict[int, Task] = {}
self._ended: Dict[int, Task] = {} self._ended: Dict[int, Task] = {}
self._num_cancelled: int = 0
self._num_ended: int = 0
self._idx: int = self._add_pool(self) self._idx: int = self._add_pool(self)
self._name: str = name self._name: str = name
self._all_tasks_known_flag: Event = Event() self._all_tasks_known_flag: Event = Event()
self._all_tasks_known_flag.set() self._all_tasks_known_flag.set()
self._more_allowed_flag: Event = Event() self._interrupt_flag: Event = Event()
self._check_more_allowed()
log.debug("%s initialized", str(self)) log.debug("%s initialized", str(self))
def __str__(self) -> str: def __str__(self) -> str:
return f'{self.__class__.__name__}-{self._name or self._idx}' return f'{self.__class__.__name__}-{self._name or self._idx}'
@property @property
def max_running(self) -> int: def pool_size(self) -> int:
return self._max_running """Returns the maximum number of concurrently running tasks currently set in the pool."""
return self._pool_size
@pool_size.setter
def pool_size(self, value: int) -> None:
"""
Sets the maximum number of concurrently running tasks in the pool.
Args:
value:
A non-negative integer.
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
Raises:
`ValueError` if `value` is less than 0.
"""
if value < 0:
raise ValueError("Pool size can not be less than 0")
self._enough_room._value = value
self._pool_size = value
@property @property
def is_open(self) -> bool: def is_open(self) -> bool:
"""Returns `True` if more the pool has not been closed yet."""
return self._open return self._open
@property @property
def num_running(self) -> int: def num_running(self) -> int:
"""
Returns the number of tasks in the pool that are (at that moment) still running.
At the moment a task's `end_callback` is fired, it is no longer considered to be running.
"""
return len(self._running) return len(self._running)
@property @property
def num_cancelled(self) -> int: def num_cancelled(self) -> int:
return len(self._cancelled) """
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
"""
return self._num_cancelled
@property @property
def num_ended(self) -> int: def num_ended(self) -> int:
return len(self._ended) """
Returns the number of tasks started through the pool that have stopped running (up until that moment).
At the moment a task's `end_callback` is fired, it is considered ended.
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
does it then actually end.
"""
return self._num_ended
@property @property
def num_finished(self) -> int: def num_finished(self) -> int:
return self.num_ended - self.num_cancelled """
Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
"""
return self._num_ended - self._num_cancelled + len(self._cancelled)
@property @property
def is_full(self) -> bool: def is_full(self) -> bool:
return not self._more_allowed_flag.is_set() """
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
def _check_more_allowed(self) -> None: When the pool is full, any call to start a new task within it will block.
if self.is_full and self.num_running < self.max_running: """
self._more_allowed_flag.set() return self._enough_room.locked()
elif not self.is_full and self.num_running >= self.max_running:
self._more_allowed_flag.clear()
# TODO: Consider adding task group names
def _task_name(self, task_id: int) -> str: def _task_name(self, task_id: int) -> str:
"""Returns a standardized name for a task with a specific `task_id`."""
return f'{self}_Task-{task_id}' return f'{self}_Task-{task_id}'
async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None: async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
"""
Universal callback to be run upon any task in the pool being cancelled.
Required for keeping track of running/cancelled tasks and proper logging.
Args:
task_id:
The ID of the task that has been cancelled.
custom_callback (optional):
A callback to execute after cancellation of the task.
It is run at the end of this function with the `task_id` as its only positional argument.
"""
log.debug("Cancelling %s ...", self._task_name(task_id)) log.debug("Cancelling %s ...", self._task_name(task_id))
task = self._running.pop(task_id) self._cancelled[task_id] = self._running.pop(task_id)
assert task is not None self._num_cancelled += 1
self._cancelled[task_id] = task
await _execute_function(custom_callback, args=(task_id, ))
log.debug("Cancelled %s", self._task_name(task_id)) log.debug("Cancelled %s", self._task_name(task_id))
async def _end_task(self, task_id: int, custom_callback: FinalCallbackT = None) -> None:
task = self._running.pop(task_id, None)
if task is None:
task = self._cancelled[task_id]
self._ended[task_id] = task
await _execute_function(custom_callback, args=(task_id, )) await _execute_function(custom_callback, args=(task_id, ))
self._check_more_allowed()
log.info("Ended %s", self._task_name(task_id))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, final_callback: FinalCallbackT = None, async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
"""
Universal callback to be run upon any task in the pool ending its work.
Required for keeping track of running/cancelled/ended tasks and proper logging.
Also releases room in the task pool for potentially waiting tasks.
Args:
task_id:
The ID of the task that has reached its end.
custom_callback (optional):
A callback to execute after the task has ended.
It is run at the end of this function with the `task_id` as its only positional argument.
"""
try:
self._ended[task_id] = self._running.pop(task_id)
except KeyError:
self._ended[task_id] = self._cancelled.pop(task_id)
self._num_ended += 1
self._enough_room.release()
log.info("Ended %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, ))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any: cancel_callback: CancelCallbackT = None) -> Any:
"""
Universal wrapper around every task to be run in the pool.
Returns/raises whatever the wrapped coroutine does.
Args:
awaitable:
The actual coroutine to be run within the task pool.
task_id:
The ID of the newly created task.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the `task_id` as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
It is run with the `task_id` as its only positional argument.
"""
log.info("Started %s", self._task_name(task_id)) log.info("Started %s", self._task_name(task_id))
try: try:
return await awaitable return await awaitable
except CancelledError: except CancelledError:
await self._cancel_task(task_id, custom_callback=cancel_callback) await self._task_cancellation(task_id, custom_callback=cancel_callback)
finally: finally:
await self._end_task(task_id, custom_callback=final_callback) await self._task_ending(task_id, custom_callback=end_callback)
def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None, async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int: cancel_callback: CancelCallbackT = None) -> int:
if not (self._open or ignore_closed): """
Starts a coroutine as a new task in the pool.
This method blocks, **only if** the pool is full.
Returns/raises whatever the wrapped coroutine does.
Args:
awaitable:
The actual coroutine to be run within the task pool.
ignore_closed (optional):
If `True`, even if the pool is closed, the task will still be started.
end_callback (optional):
A callback to execute after the task has ended.
It is run with the `task_id` as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
It is run with the `task_id` as its only positional argument.
Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
"""
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsClosed("Cannot start new tasks")
await self._enough_room.acquire()
task_id = self._counter task_id = self._counter
self._counter += 1 self._counter += 1
self._running[task_id] = create_task( try:
self._task_wrapper(awaitable, task_id, final_callback, cancel_callback), self._running[task_id] = create_task(
name=self._task_name(task_id) self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
) name=self._task_name(task_id)
self._check_more_allowed() )
except Exception as e:
self._enough_room.release()
raise e
return task_id return task_id
def _cancel_one(self, task_id: int, msg: str = None) -> None: def _get_running_task(self, task_id: int) -> Task:
"""
Gets a running task by its task ID.
Args:
task_id: The ID of a task still running within the pool.
Raises:
`asyncio_taskpool.exceptions.AlreadyCancelled` if the task with `task_id` has been (recently) cancelled.
`asyncio_taskpool.exceptions.AlreadyEnded` if the task with `task_id` has ended (recently).
`asyncio_taskpool.exceptions.InvalidTaskID` if no task with `task_id` is known to the pool.
"""
try: try:
task = self._running[task_id] return self._running[task_id]
except KeyError: except KeyError:
if self._cancelled.get(task_id): if self._cancelled.get(task_id):
raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled") raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
if self._ended.get(task_id): if self._ended.get(task_id):
raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running") raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
task.cancel(msg=msg)
def cancel(self, *task_ids: int, msg: str = None) -> None: def cancel(self, *task_ids: int, msg: str = None) -> None:
for task_id in task_ids: """
self._cancel_one(task_id, msg=msg) Cancels the tasks with the specified IDs.
Each task ID must belong to a task still running within the pool. Otherwise one of the following exceptions will
be raised:
- `AlreadyCancelled` if one of the `task_ids` belongs to a task that has been (recently) cancelled.
- `AlreadyEnded` if one of the `task_ids` belongs to a task that has ended (recently).
- `InvalidTaskID` if any of the `task_ids` is not known to the pool.
Note that once a pool has been flushed, any IDs of tasks that have ended previously will be forgotten.
Args:
task_ids:
Arbitrary number of integers. Each must be an ID of a task still running within the pool.
msg (optional):
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
"""
tasks = [self._get_running_task(task_id) for task_id in task_ids]
for task in tasks:
task.cancel(msg=msg)
def cancel_all(self, msg: str = None) -> None: def cancel_all(self, msg: str = None) -> None:
"""
Cancels all tasks still running within the pool.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
than the number of arguments from `args_iter`.
In this case, those already running will be cancelled, while the following will **never even start**.
Args:
msg (optional):
Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
"""
log.warning("%s cancelling all tasks!", str(self))
self._interrupt_flag.set()
for task in self._running.values(): for task in self._running.values():
task.cancel(msg=msg) task.cancel(msg=msg)
async def flush(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, returns their results, and forgets the tasks.
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
callbacks registered for the tasks block.
Args:
return_exceptions (optional): Passed directly into `gather`.
"""
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
self._ended = self._cancelled = {}
if self._interrupt_flag.is_set():
self._interrupt_flag.clear()
return results
def close(self) -> None: def close(self) -> None:
"""Disallows any more tasks to be started in the pool."""
self._open = False self._open = False
log.info("%s is closed!", str(self)) log.info("%s is closed!", str(self))
async def gather(self, return_exceptions: bool = False): async def gather(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on **all** tasks from the pool, returns their results, and forgets the tasks.
The `close()` method must have been called prior to this.
Note that there may be an unknown number of coroutine functions "queued" to be run as tasks.
This can happen, if for example the `TaskPool.map` method was called with `num_tasks` set to a number smaller
than the number of arguments from `args_iter`.
In this case, calling `cancel_all()` prior to this, will prevent those tasks from starting and potentially
blocking this method. Otherwise it will wait until they all have started.
This method may also block, if any task blocks while catching a `asyncio.CancelledError` or if any of the
callbacks registered for a task blocks.
Args:
return_exceptions (optional): Passed directly into `gather`.
Raises:
`asyncio_taskpool.exceptions.PoolStillOpen` if the pool has not been closed yet.
"""
if self._open: if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
await self._all_tasks_known_flag.wait() await self._all_tasks_known_flag.wait()
results = await gather(*self._running.values(), *self._ended.values(), return_exceptions=return_exceptions) results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
self._running = self._cancelled = self._ended = {} return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {}
if self._interrupt_flag.is_set():
self._interrupt_flag.clear()
return results return results
class TaskPool(BaseTaskPool): class TaskPool(BaseTaskPool):
def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
return self._start_task(func(*args, **kwargs), final_callback=final_callback, cancel_callback=cancel_callback) return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
return tuple(self._apply_one(func, args, kwargs, final_callback, cancel_callback) for _ in range(num)) return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))
@staticmethod @staticmethod
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]: def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
@ -175,54 +359,54 @@ class TaskPool(BaseTaskPool):
return func(**arg) return func(**arg)
raise ValueError raise ValueError
def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
if self._all_tasks_known_flag.is_set(): if self._all_tasks_known_flag.is_set():
self._all_tasks_known_flag.clear() self._all_tasks_known_flag.clear()
args_iter = iter(args_iter) args_iter = iter(args_iter)
def _start_next_coroutine() -> bool: async def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars) cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None: if cor is None or self._interrupt_flag.is_set():
self._all_tasks_known_flag.set() self._all_tasks_known_flag.set()
return True return True
self._start_task(cor, ignore_closed=True, final_callback=_start_next, cancel_callback=cancel_callback) await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
return False return False
async def _start_next(task_id: int) -> None: async def _start_next(task_id: int) -> None:
await _execute_function(final_callback, args=(task_id, )) await _start_next_coroutine()
_start_next_coroutine() await _execute_function(end_callback, args=(task_id, ))
for _ in range(num_tasks): for _ in range(num_tasks):
reached_end = _start_next_coroutine() reached_end = await _start_next_coroutine()
if reached_end: if reached_end:
break break
def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks, await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback) end_callback=end_callback, cancel_callback=cancel_callback)
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1, async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks, await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback) end_callback=end_callback, cancel_callback=cancel_callback)
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1, async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback) end_callback=end_callback, cancel_callback=cancel_callback)
class SimpleTaskPool(BaseTaskPool): class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None: name: str = None) -> None:
self._func: CoroutineFunc = func self._func: CoroutineFunc = func
self._args: ArgsT = args self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {} self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
self._final_callback: FinalCallbackT = final_callback self._end_callback: EndCallbackT = end_callback
self._cancel_callback: CancelCallbackT = cancel_callback self._cancel_callback: CancelCallbackT = cancel_callback
super().__init__(name=name) super().__init__(name=name)
@ -234,12 +418,12 @@ class SimpleTaskPool(BaseTaskPool):
def size(self) -> int: def size(self) -> int:
return self.num_running return self.num_running
def _start_one(self) -> int: async def _start_one(self) -> int:
return self._start_task(self._func(*self._args, **self._kwargs), return await self._start_task(self._func(*self._args, **self._kwargs),
final_callback=self._final_callback, cancel_callback=self._cancel_callback) end_callback=self._end_callback, cancel_callback=self._cancel_callback)
def start(self, num: int = 1) -> List[int]: async def start(self, num: int = 1) -> List[int]:
return [self._start_one() for _ in range(num)] return [await self._start_one() for _ in range(num)]
def stop(self, num: int = 1) -> List[int]: def stop(self, num: int = 1) -> List[int]:
num = min(num, self.size) num = min(num, self.size)

View File

@ -45,11 +45,11 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
self._server_kwargs = server_kwargs self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None self._server: Optional[AbstractServer] = None
def _start_tasks(self, writer: StreamWriter, num: int = None) -> None: async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None: if num is None:
num = 1 num = 1
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num)) log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
writer.write(str(self._pool.start(num)).encode()) writer.write(str(await self._pool.start(num)).encode())
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None: if num is None:
@ -78,7 +78,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
break break
cmd, arg = get_cmd_arg(msg) cmd, arg = get_cmd_arg(msg)
if cmd == constants.CMD_START: if cmd == constants.CMD_START:
self._start_tasks(writer, arg) await self._start_tasks(writer, arg)
elif cmd == constants.CMD_STOP: elif cmd == constants.CMD_STOP:
self._stop_tasks(writer, arg) self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL: elif cmd == constants.CMD_STOP_ALL:

View File

@ -5,7 +5,7 @@ from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
ArgsT = Iterable[Any] ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Awaitable[Any]]
FinalCallbackT = Callable EndCallbackT = Callable
CancelCallbackT = Callable CancelCallbackT = Callable
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

View File

@ -1,35 +1,52 @@
import asyncio import asyncio
from unittest import TestCase from asyncio.exceptions import CancelledError
from unittest.mock import MagicMock, PropertyMock, patch, call from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from asyncio_taskpool import pool from asyncio_taskpool import pool, exceptions
EMPTY_LIST, EMPTY_DICT = [], {} EMPTY_LIST, EMPTY_DICT = [], {}
FOO, BAR = 'foo', 'bar'
class BaseTaskPoolTestCase(TestCase): class TestException(Exception):
pass
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = pool.log.level
pool.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
def setUp(self) -> None: def setUp(self) -> None:
self._pools = getattr(pool.BaseTaskPool, '_pools') self._pools = getattr(pool.BaseTaskPool, '_pools')
# These three methods are called during initialization, so we mock them by default during setup # These three methods are called during initialization, so we mock them by default during setup
self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool') self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
self._check_more_allowed_patcher = patch.object(pool.BaseTaskPool, '_check_more_allowed') self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__') self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
self.mock__add_pool = self._add_pool_patcher.start() self.mock__add_pool = self._add_pool_patcher.start()
self.mock__check_more_allowed = self._check_more_allowed_patcher.start() self.mock_pool_size = self.pool_size_patcher.start()
self.mock___str__ = self.__str___patcher.start() self.mock___str__ = self.__str___patcher.start()
self.mock__add_pool.return_value = self.mock_idx = 123 self.mock__add_pool.return_value = self.mock_idx = 123
self.mock___str__.return_value = self.mock_str = 'foobar' self.mock___str__.return_value = self.mock_str = 'foobar'
# Test pool parameters: # Test pool parameters:
self.mock_pool_params = {'max_running': 420, 'name': 'test123'} self.test_pool_size, self.test_pool_name = 420, 'test123'
self.task_pool = pool.BaseTaskPool(**self.mock_pool_params) self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
def tearDown(self) -> None: def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools) setattr(pool.TaskPool, '_pools', self._pools)
self._add_pool_patcher.stop() self._add_pool_patcher.stop()
self._check_more_allowed_patcher.stop() self.pool_size_patcher.stop()
self.__str___patcher.stop() self.__str___patcher.stop()
def test__add_pool(self): def test__add_pool(self):
@ -40,105 +57,309 @@ class BaseTaskPoolTestCase(TestCase):
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools')) self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
def test_init(self): def test_init(self):
for key, value in self.mock_pool_params.items(): self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
self.assertEqual(value, getattr(self.task_pool, f'_{key}')) self.assertTrue(self.task_pool._open)
self.assertEqual(0, self.task_pool._counter)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running) self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set()) self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
self.assertIsInstance(self.task_pool._more_allowed_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock__check_more_allowed.assert_called_once_with() self.mock_pool_size.assert_called_once_with(self.test_pool_size)
self.mock___str__.assert_called_once_with() self.mock___str__.assert_called_once_with()
def test___str__(self): def test___str__(self):
self.__str___patcher.stop() self.__str___patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.mock_pool_params["name"]}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
setattr(self.task_pool, '_name', None) setattr(self.task_pool, '_name', None)
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}' expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool)) self.assertEqual(expected_str, str(self.task_pool))
def test_max_running(self): def test_pool_size(self):
self.task_pool._max_running = foo = 'foo' self.pool_size_patcher.stop()
self.assertEqual(foo, self.task_pool.max_running) self.task_pool._pool_size = self.test_pool_size
self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
with self.assertRaises(ValueError):
self.task_pool.pool_size = -1
self.task_pool.pool_size = new_size = 69
self.assertEqual(new_size, self.task_pool._pool_size)
def test_is_open(self): def test_is_open(self):
self.task_pool._open = foo = 'foo' self.task_pool._open = FOO
self.assertEqual(foo, self.task_pool.is_open) self.assertEqual(FOO, self.task_pool.is_open)
def test_num_running(self): def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz'] self.task_pool._running = ['foo', 'bar', 'baz']
self.assertEqual(3, self.task_pool.num_running) self.assertEqual(3, self.task_pool.num_running)
def test_num_cancelled(self): def test_num_cancelled(self):
self.task_pool._cancelled = ['foo', 'bar', 'baz'] self.task_pool._num_cancelled = 3
self.assertEqual(3, self.task_pool.num_cancelled) self.assertEqual(3, self.task_pool.num_cancelled)
def test_num_ended(self): def test_num_ended(self):
self.task_pool._ended = ['foo', 'bar', 'baz'] self.task_pool._num_ended = 3
self.assertEqual(3, self.task_pool.num_ended) self.assertEqual(3, self.task_pool.num_ended)
@patch.object(pool.BaseTaskPool, 'num_ended', new_callable=PropertyMock) def test_num_finished(self):
@patch.object(pool.BaseTaskPool, 'num_cancelled', new_callable=PropertyMock) self.task_pool._num_cancelled = cancelled = 69
def test_num_finished(self, mock_num_cancelled: MagicMock, mock_num_ended: MagicMock): self.task_pool._num_ended = ended = 420
mock_num_cancelled.return_value = cancelled = 69 self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
mock_num_ended.return_value = ended = 420 self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)
self.assertEqual(ended - cancelled, self.task_pool.num_finished)
def test_is_full(self): def test_is_full(self):
self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full) self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
@patch.object(pool.BaseTaskPool, 'max_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock)
def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock,
mock_max_running: MagicMock):
def reset_mocks():
mock_is_full.reset_mock()
mock_num_running.reset_mock()
mock_max_running.reset_mock()
self._check_more_allowed_patcher.stop()
# Just reaching limit, we expect flag to become unset:
mock_is_full.return_value = False
mock_num_running.return_value = mock_max_running.return_value = 420
self.task_pool._more_allowed_flag.clear()
self.task_pool._check_more_allowed()
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks()
# Already at limit, we expect nothing to change:
mock_is_full.return_value = True
self.task_pool._check_more_allowed()
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks()
# Just finished a task, we expect flag to become set:
mock_num_running.return_value = 419
self.task_pool._check_more_allowed()
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_called_once_with()
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
reset_mocks()
# In this state we expect the flag to remain unchanged change:
mock_is_full.return_value = False
self.task_pool._check_more_allowed()
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
mock_max_running.assert_called_once_with()
def test__task_name(self): def test__task_name(self):
i = 123 i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@patch.object(pool, '_execute_function')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool, '_execute_function')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
# End running task:
self.task_pool._running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock__task_name.reset_mock()
mock__execute_function.reset_mock()
# End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._cancelled)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool.BaseTaskPool, '_task_ending')
@patch.object(pool.BaseTaskPool, '_task_cancellation')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_wrapper(self, mock__task_name: MagicMock,
mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock):
task_id = 42
mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock()
mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError)
# Cancelled during execution:
mock_awaitable = mock_coroutine_func()
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertIsNone(output)
mock_coroutine_func.assert_awaited_once()
mock__task_name.assert_called_with(task_id)
mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb)
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
mock_coroutine_func.reset_mock(side_effect=True)
mock__task_name.reset_mock()
mock__task_cancellation.reset_mock()
mock__task_ending.reset_mock()
# Not cancelled:
mock_awaitable = mock_coroutine_func()
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(FOO, output)
mock_coroutine_func.assert_awaited_once()
mock__task_name.assert_called_with(task_id)
mock__task_cancellation.assert_not_awaited()
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
@patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
def reset_mocks() -> None:
mock_is_open.reset_mock()
mock__task_name.reset_mock()
mock__task_wrapper.reset_mock()
mock_create_task.reset_mock()
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123
mock_is_open.return_value = ignore_closed = False
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
ignore_closed = True
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
self.assertEqual(room - 1, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
reset_mocks()
self.task_pool._counter = count
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
def test__get_running_task(self, mock__task_name: MagicMock):
task_id, mock_task = 555, MagicMock()
self.task_pool._running[task_id] = mock_task
output = self.task_pool._get_running_task(task_id)
self.assertEqual(mock_task, output)
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
with self.assertRaises(exceptions.AlreadyCancelled):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
with self.assertRaises(exceptions.TaskEnded):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
del self.task_pool._ended[task_id]
with self.assertRaises(exceptions.InvalidTaskID):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock):
task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
def test_cancel_all(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
self.task_pool._running = {1: mock_task1, 2: mock_task2}
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(self.task_pool.cancel_all(FOO))
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_task1.cancel.assert_called_once_with(msg=FOO)
mock_task2.cancel.assert_called_once_with(msg=FOO)
async def test_flush(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._interrupt_flag.set()
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
def test_close(self):
assert self.task_pool._open
self.task_pool.close()
self.assertFalse(self.task_pool._open)
async def test_gather(self):
mock_wait = AsyncMock()
self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR)
self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set()
assert self.task_pool._open
with self.assertRaises(exceptions.PoolStillOpen):
await self.task_pool.gather()
self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled)
self.assertDictEqual(self.task_pool._running, running)
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_not_awaited()
self.task_pool._open = False
def check_assertions() -> None:
self.assertListEqual([FOO, test_exception, BAR], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_awaited_once_with()
output = await self.task_pool.gather(return_exceptions=True)
check_assertions()
mock_wait.reset_mock()
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
output = await self.task_pool.gather(return_exceptions=True)
check_assertions()

View File

@ -33,9 +33,9 @@ async def work(n: int) -> None:
async def main() -> None: async def main() -> None:
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
pool.start(3) # launches work tasks 0, 1, and 2 await pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.start() # launches work task 3 await pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2 pool.stop(2) # cancels tasks 3 and 2
pool.close() # required for the last line pool.close() # required for the last line

View File

@ -44,7 +44,7 @@ async def main() -> None:
for item in range(100): for item in range(100):
q.put_nowait(item) q.put_nowait(item)
pool = SimpleTaskPool(worker, (q,)) # initializes the pool pool = SimpleTaskPool(worker, (q,)) # initializes the pool
pool.start(3) # launches three worker tasks await pool.start(3) # launches three worker tasks
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue. # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join() await q.join()