Compare commits

...

3 Commits

Author SHA1 Message Date
4c6a5412ca removed num_cancellations and num_finished from the pool interface; added num_cancelled; made the num argument non-optional in the start and stop methods of the SimpleTaskPool; changed some internals; improved docstrings 2022-03-17 13:52:02 +01:00
44c03cc493 catching exceptions in _queue_consumer meta task 2022-03-16 16:57:03 +01:00
689a74c678 control interface now supports TaskPool instances:
dotted paths to coroutine functions can be passed to the parser as arguments for methods like `map`;
parser supports literal evaluation for the argument iterables in methods like `map`;
minor fixes
2022-03-16 11:27:27 +01:00
9 changed files with 184 additions and 102 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.6.4
version = 0.8.0
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks
@ -9,7 +9,7 @@ long_description_content_type = text/markdown
keywords = asyncio, concurrency, tasks, coroutines, asynchronous, server
url = https://git.fajnberg.de/daniil/asyncio-taskpool
project_urls =
Bug Tracker = https://git.fajnberg.de/daniil/asyncio-taskpool/issues
Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues
classifiers =
Development Status :: 3 - Alpha
Programming Language :: Python :: 3

View File

@ -20,14 +20,16 @@ This module contains the the definition of the `ControlParser` class used by a c
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
from ast import literal_eval
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature
from shutil import get_terminal_size
from typing import Any, Callable, Container, Dict, Set, Type, TypeVar
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
from ..constants import CLIENT_INFO, CMD, STREAM_WRITER
from ..exceptions import HelpRequested, ParserError
from ..helpers import get_first_doc_line
from ..helpers import get_first_doc_line, resolve_dotted_path
from ..types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
@ -266,7 +268,7 @@ class ControlParser(ArgumentParser):
kwargs.setdefault('nargs', '*')
if not kwargs.get('action') == 'store_true':
# Set the type from the parameter annotation.
kwargs.setdefault('type', _get_arg_type_wrapper(parameter.annotation))
kwargs.setdefault('type', _get_type_from_annotation(parameter.annotation))
return self.add_argument(*name_or_flags, **kwargs)
def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None:
@ -300,3 +302,11 @@ def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
wrapper.__name__ = cls.__name__
return wrapper
def _get_type_from_annotation(annotation: Type) -> Callable[[Any], Any]:
if any(annotation is t for t in {CoroutineFunc, EndCB, CancelCB}):
annotation = resolve_dotted_path
if any(annotation is t for t in {ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]}):
annotation = literal_eval
return _get_arg_type_wrapper(annotation)

View File

@ -37,7 +37,7 @@ from .session import ControlSession
log = logging.getLogger(__name__)
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
class ControlServer(ABC):
"""
Abstract base class for a task pool control server.

View File

@ -21,6 +21,7 @@ Miscellaneous helper functions. None of these should be considered part of the p
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from importlib import import_module
from inspect import getdoc
from typing import Any, Optional, Union
@ -63,3 +64,22 @@ async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwarg
return _function_to_execute(*args, **kwargs)
except Exception as e:
return e
def resolve_dotted_path(dotted_path: str) -> object:
"""
Resolves a dotted path to a global object and returns that object.
Algorithm shamelessly stolen from the `logging.config` module from the standard library.
"""
names = dotted_path.split('.')
module_name = names.pop(0)
found = import_module(module_name)
for name in names:
try:
found = getattr(found, name)
except AttributeError:
module_name += f'.{name}'
import_module(module_name)
found = getattr(found, name)
return found

View File

@ -59,16 +59,14 @@ class BaseTaskPool:
@classmethod
def _add_pool(cls, pool: 'BaseTaskPool') -> int:
"""Adds a `pool` (instance of any subclass) to the general list of pools and returns it's index in the list."""
"""Adds a `pool` to the general list of pools and returns it's index."""
cls._pools.append(pool)
return len(cls._pools) - 1
def __init__(self, pool_size: int = inf, name: str = None) -> None:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
# Initialize a counter for the total number of tasks started through the pool and one for the total number of
# tasks cancelled through the pool.
# Initialize a counter for the total number of tasks started through the pool.
self._num_started: int = 0
self._num_cancellations: int = 0
# Initialize flags; immutably set the name.
self._locked: bool = False
@ -97,30 +95,29 @@ class BaseTaskPool:
@property
def pool_size(self) -> int:
"""Returns the maximum number of concurrently running tasks currently set in the pool."""
return self._pool_size
"""Maximum number of concurrently running tasks allowed in the pool."""
return getattr(self._enough_room, '_value')
@pool_size.setter
def pool_size(self, value: int) -> None:
"""
Sets the maximum number of concurrently running tasks in the pool.
Args:
value:
A non-negative integer.
NOTE: Increasing the pool size will immediately start tasks that are awaiting enough room to run.
Args:
value: A non-negative integer.
Raises:
`ValueError` if `value` is less than 0.
"""
if value < 0:
raise ValueError("Pool size can not be less than 0")
self._enough_room._value = value
self._pool_size = value
@property
def is_locked(self) -> bool:
"""Returns `True` if the pool has been locked (see below)."""
"""`True` if the pool has been locked (see below)."""
return self._locked
def lock(self) -> None:
@ -138,26 +135,26 @@ class BaseTaskPool:
@property
def num_running(self) -> int:
"""
Returns the number of tasks in the pool that are (at that moment) still running.
Number of tasks in the pool that are still running.
At the moment a task's `end_callback` or `cancel_callback` is fired, it is no longer considered running.
"""
return len(self._tasks_running)
@property
def num_cancellations(self) -> int:
def num_cancelled(self) -> int:
"""
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
Number of tasks in the pool that have been cancelled.
At the moment a task's `cancel_callback` is fired, this counts as a cancellation, and the task is then
considered cancelled (instead of running) until its `end_callback` is fired.
At the moment a task's `cancel_callback` is fired, it is considered to be cancelled and no longer running,
until its `end_callback` is fired, at which point it is considered ended (instead of cancelled).
"""
return self._num_cancellations
return len(self._tasks_cancelled)
@property
def num_ended(self) -> int:
"""
Returns the number of tasks started through the pool that have stopped running (up until that moment).
Number of tasks in the pool that have stopped running.
At the moment a task's `end_callback` is fired, it is considered ended and no longer running (or cancelled).
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
@ -165,16 +162,12 @@ class BaseTaskPool:
"""
return len(self._tasks_ended)
@property
def num_finished(self) -> int:
"""Returns the number of tasks in the pool that have finished running (without having been cancelled)."""
return len(self._tasks_ended) - self._num_cancellations + len(self._tasks_cancelled)
@property
def is_full(self) -> bool:
"""
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
When the pool is full, any call to start a new task within it will block.
`False` if the number of running tasks is less than the `pool_size`.
When the pool is full, any call to start a new task within it will block, until there is enough room for it.
"""
return self._enough_room.locked()
@ -247,7 +240,6 @@ class BaseTaskPool:
"""
log.debug("Cancelling %s ...", self._task_name(task_id))
self._tasks_cancelled[task_id] = self._tasks_running.pop(task_id)
self._num_cancellations += 1
log.debug("Cancelled %s", self._task_name(task_id))
await execute_optional(custom_callback, args=(task_id,))
@ -276,7 +268,9 @@ class BaseTaskPool:
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCB = None,
cancel_callback: CancelCB = None) -> Any:
"""
Universal wrapper around every task run in the pool that returns/raises whatever the wrapped coroutine does.
Universal wrapper around every task run in the pool.
Returns/raises whatever the wrapped coroutine does.
Responsible for catching cancellation and awaiting the `_task_cancellation` callback, as well as for awaiting
the `_task_ending` callback, after the coroutine returns or raises an exception.
@ -381,7 +375,9 @@ class BaseTaskPool:
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
"""
Removes all tasks from the specified group and cancels them, if they are still running.
Removes all tasks from the specified group and cancels them.
Does nothing to tasks, that are no longer running.
Args:
group_name: The name of the group of tasks that shall be cancelled.
@ -397,7 +393,9 @@ class BaseTaskPool:
async def cancel_group(self, group_name: str, msg: str = None) -> None:
"""
Cancels an entire group of tasks. The task group is subsequently forgotten by the pool.
Cancels an entire group of tasks.
The task group is subsequently forgotten by the pool.
Args:
group_name: The name of the group of tasks that shall be cancelled.
@ -430,12 +428,13 @@ class BaseTaskPool:
async def flush(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks.
Calls `asyncio.gather` on all ended/cancelled tasks in the pool.
This method exists mainly to free up memory of unneeded `Task` objects.
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
`Task` objects.
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
callbacks registered for the tasks block.
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
registered for the tasks block.
Args:
return_exceptions (optional): Passed directly into `gather`.
@ -446,7 +445,9 @@ class BaseTaskPool:
async def gather_and_close(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool.
Calls `asyncio.gather` on **all** tasks in the pool, then closes it.
After this method returns, no more tasks can be started in the pool.
The `lock()` method must have been called prior to this.
@ -473,7 +474,9 @@ class BaseTaskPool:
class TaskPool(BaseTaskPool):
"""
General task pool class. Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
General purpose task pool class.
Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
Tasks in the pool can all belong to the same coroutine function,
@ -506,12 +509,15 @@ class TaskPool(BaseTaskPool):
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
"""See base class."""
self._cancel_group_meta_tasks(group_name)
super()._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
async def cancel_group(self, group_name: str, msg: str = None) -> None:
"""
Cancels an entire group of tasks. The task group is subsequently forgotten by the pool.
Cancels an entire group of tasks.
The task group is subsequently forgotten by the pool.
If any methods such as `map()` launched meta tasks belonging to that group, these meta tasks are cancelled
before the actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will
@ -529,7 +535,7 @@ class TaskPool(BaseTaskPool):
async def cancel_all(self, msg: str = None) -> None:
"""
Cancels all tasks still running within the pool. (This includes all meta tasks.)
Cancels all tasks still running within the pool (including meta tasks).
If any methods such as `map()` launched meta tasks, these meta tasks are cancelled before the actual tasks are
cancelled. This means that any tasks "queued" to be started by a meta task will **never even start**. In the
@ -569,12 +575,13 @@ class TaskPool(BaseTaskPool):
async def flush(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on all ended/cancelled tasks from the pool, and forgets the tasks.
Calls `asyncio.gather` on all ended/cancelled tasks in the pool.
This method exists mainly to free up memory of unneeded `Task` objects. It also gets rid of unneeded meta tasks.
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
`Task` objects. It also gets rid of unneeded meta tasks.
This method blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the
callbacks registered for the tasks block.
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
registered for the tasks block.
Args:
return_exceptions (optional): Passed directly into `gather`.
@ -587,7 +594,9 @@ class TaskPool(BaseTaskPool):
async def gather_and_close(self, return_exceptions: bool = False):
"""
Calls `asyncio.gather` on **all** tasks in the pool, then permanently closes the pool.
Calls `asyncio.gather` on **all** tasks in the pool, then closes it.
After this method returns, no more tasks can be started in the pool.
The `lock()` method must have been called prior to this.
@ -596,7 +605,6 @@ class TaskPool(BaseTaskPool):
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
this, make sure to call `cancel_all()` prior to this.
This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
the callbacks registered for a task blocks for whatever reason.
@ -662,9 +670,13 @@ class TaskPool(BaseTaskPool):
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
Creates an arbitrary number of coroutines with the supplied arguments and runs them as new tasks in the pool.
Creates tasks with the supplied arguments to be run in the pool.
Each coroutine looks like `func(*args, **kwargs)`, meaning the `args` and `kwargs` are unpacked and passed
into `func` before creating each task, and this is done `num` times.
All the new tasks are added to the same task group.
Each coroutine looks like `func(*args, **kwargs)`. All the new tasks are added to the same task group.
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
Args:
@ -775,15 +787,23 @@ class TaskPool(BaseTaskPool):
if next_arg is self._QUEUE_END_SENTINEL:
# The `_queue_producer()` either reached the last argument or was cancelled.
return
try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
except Exception as e:
# This means an exception occurred during task **creation**, meaning no task has been created.
# It does not imply an error within the task itself.
log.exception("%s occurred while trying to create task: %s(%s%s)",
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
map_semaphore.release()
async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Creates tasks in the pool with arguments from the supplied iterable.
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
All the new tasks are added to the same task group.
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
@ -840,10 +860,11 @@ class TaskPool(BaseTaskPool):
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
An asyncio-task-based equivalent of the `multiprocessing.pool.Pool.map` method.
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
All the new tasks are added to the same task group.
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
@ -939,6 +960,7 @@ class SimpleTaskPool(BaseTaskPool):
end_callback: EndCB = None, cancel_callback: CancelCB = None,
pool_size: int = inf, name: str = None) -> None:
"""
Initializes all required attributes.
Args:
func:
@ -957,6 +979,9 @@ class SimpleTaskPool(BaseTaskPool):
The maximum number of tasks allowed to run concurrently in the pool
name (optional):
An optional name for the pool.
Raises:
`NotCoroutine` if `func` is not a coroutine function.
"""
if not iscoroutinefunction(func):
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
@ -969,7 +994,7 @@ class SimpleTaskPool(BaseTaskPool):
@property
def func_name(self) -> str:
"""Returns the name of the coroutine function used in the pool."""
"""Name of the coroutine function used in the pool."""
return self._func.__name__
async def _start_one(self) -> int:
@ -977,18 +1002,18 @@ class SimpleTaskPool(BaseTaskPool):
return await self._start_task(self._func(*self._args, **self._kwargs),
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
async def start(self, num: int = 1) -> List[int]:
"""Starts `num` new tasks within the pool and returns their IDs as a list."""
async def start(self, num: int) -> List[int]:
"""Starts `num` new tasks within the pool and returns their IDs."""
ids = await gather(*(self._start_one() for _ in range(num)))
assert isinstance(ids, list) # for PyCharm (see above to-do-item)
assert isinstance(ids, list) # for PyCharm
return ids
def stop(self, num: int = 1) -> List[int]:
def stop(self, num: int) -> List[int]:
"""
Cancels `num` running tasks within the pool and returns their IDs as a list.
Cancels `num` running tasks within the pool and returns their IDs.
The tasks are canceled in LIFO order, meaning tasks started later will be stopped before those started earlier.
If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled.
If `num` is greater than or equal to the number of currently running tasks, all tasks are cancelled.
"""
ids = []
for i, task_id in enumerate(reversed(self._tasks_running)):
@ -999,5 +1024,5 @@ class SimpleTaskPool(BaseTaskPool):
return ids
def stop_all(self) -> List[int]:
"""Cancels all running tasks and returns their IDs as a list."""
"""Cancels all running tasks and returns their IDs."""
return self.stop(self.num_running)

View File

@ -20,12 +20,16 @@ Unittests for the `asyncio_taskpool.control.parser` module.
from argparse import ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, SUPPRESS
from ast import literal_eval
from inspect import signature
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
from typing import Iterable
from asyncio_taskpool.control import parser
from asyncio_taskpool.exceptions import HelpRequested, ParserError
from asyncio_taskpool.helpers import resolve_dotted_path
from asyncio_taskpool.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
FOO, BAR = 'foo', 'bar'
@ -194,11 +198,11 @@ class ControlServerTestCase(TestCase):
self.parser.print_help(arg)
mock_print_help.assert_called_once_with(arg)
@patch.object(parser, '_get_arg_type_wrapper')
@patch.object(parser, '_get_type_from_annotation')
@patch.object(parser.ArgumentParser, 'add_argument')
def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_arg_type_wrapper: MagicMock):
def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_type_from_annotation: MagicMock):
mock_add_argument.return_value = expected_output = 'action'
mock__get_arg_type_wrapper.return_value = mock_type = 'fake'
mock__get_type_from_annotation.return_value = mock_type = 'fake'
foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex
bar_default, baz_default, boo_default = 1, 0.1, 1j
@ -211,42 +215,42 @@ class ControlServerTestCase(TestCase):
kwargs = {FOO + BAR: 'xyz'}
self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs))
mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(foo_type)
mock__get_type_from_annotation.assert_called_once_with(foo_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
mock__get_type_from_annotation.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs))
mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(args_type)
mock__get_type_from_annotation.assert_called_once_with(args_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
mock__get_type_from_annotation.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs))
mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(bar_type)
mock__get_type_from_annotation.assert_called_once_with(bar_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
mock__get_type_from_annotation.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs))
mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(baz_type)
mock__get_type_from_annotation.assert_called_once_with(baz_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
mock__get_type_from_annotation.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs))
mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(boo_type)
mock__get_type_from_annotation.assert_called_once_with(boo_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
mock__get_type_from_annotation.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs))
mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs)
mock__get_arg_type_wrapper.assert_not_called()
mock__get_type_from_annotation.assert_not_called()
@patch.object(parser.ControlParser, 'add_function_arg')
def test_add_function_args(self, mock_add_function_arg: MagicMock):
@ -266,3 +270,20 @@ class RestTestCase(TestCase):
self.assertEqual('int', type_wrap.__name__)
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
self.assertEqual(13, type_wrap('13'))
@patch.object(parser, '_get_arg_type_wrapper')
def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR
dotted_path_ann = [CoroutineFunc, EndCB, CancelCB]
literal_eval_ann = [ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]]
any_other_ann = MagicMock()
for a in dotted_path_ann:
self.assertEqual(expected_output, parser._get_type_from_annotation(a))
mock__get_arg_type_wrapper.assert_has_calls(len(dotted_path_ann) * [call(resolve_dotted_path)])
mock__get_arg_type_wrapper.reset_mock()
for a in literal_eval_ann:
self.assertEqual(expected_output, parser._get_type_from_annotation(a))
mock__get_arg_type_wrapper.assert_has_calls(len(literal_eval_ann) * [call(literal_eval)])
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, parser._get_type_from_annotation(any_other_ann))
mock__get_arg_type_wrapper.assert_called_once_with(any_other_ann)

View File

@ -20,7 +20,7 @@ Unittests for the `asyncio_taskpool.helpers` module.
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool import helpers
@ -118,3 +118,13 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(test_exception, output)
mock_func.assert_called_once_with(*args, **kwargs)
def test_resolve_dotted_path(self):
from logging import WARNING
from urllib.request import urlopen
self.assertEqual(WARNING, helpers.resolve_dotted_path('logging.WARNING'))
self.assertEqual(urlopen, helpers.resolve_dotted_path('urllib.request.urlopen'))
with patch.object(helpers, 'import_module', return_value=object) as mock_import_module:
with self.assertRaises(AttributeError):
helpers.resolve_dotted_path('foo.bar.baz')
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])

View File

@ -84,7 +84,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
def test_init(self):
self.assertEqual(0, self.task_pool._num_started)
self.assertEqual(0, self.task_pool._num_cancellations)
self.assertFalse(self.task_pool._locked)
self.assertFalse(self.task_pool._closed)
@ -114,14 +113,14 @@ class BaseTaskPoolTestCase(CommonTestCase):
def test_pool_size(self):
self.pool_size_patcher.stop()
self.task_pool._pool_size = self.TEST_POOL_SIZE
self.task_pool._enough_room._value = self.TEST_POOL_SIZE
self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size)
with self.assertRaises(ValueError):
self.task_pool.pool_size = -1
self.task_pool.pool_size = new_size = 69
self.assertEqual(new_size, self.task_pool._pool_size)
self.assertEqual(new_size, self.task_pool._enough_room._value)
def test_is_locked(self):
self.task_pool._locked = FOO
@ -145,21 +144,14 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_running)
def test_num_cancellations(self):
self.task_pool._num_cancellations = 3
self.assertEqual(3, self.task_pool.num_cancellations)
def test_num_cancelled(self):
self.task_pool._tasks_cancelled = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_cancelled)
def test_num_ended(self):
self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_ended)
def test_num_finished(self):
self.task_pool._num_cancellations = num_cancellations = 69
num_ended = 420
self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)}
self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished)
def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
@ -200,12 +192,10 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancellations = cancelled = 3
self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._tasks_running)
self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancellations)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -603,28 +593,34 @@ class TaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
@patch.object(pool, 'Semaphore')
@patch.object(pool.TaskPool, '_get_map_end_callback')
async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock,
@patch.object(pool, 'Semaphore')
async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
mock__start_task: AsyncMock, mock_star_function: MagicMock):
mock__get_map_end_callback.return_value = map_cb = MagicMock()
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
mock_star_function.return_value = awaitable = 'totally an awaitable'
arg1, arg2 = 123456789, 'function argument'
mock__get_map_end_callback.return_value = map_cb = MagicMock()
awaitable = 'totally an awaitable'
mock_star_function.side_effect = [awaitable, awaitable, Exception()]
arg1, arg2, bad = 123456789, 'function argument', None
mock_q_maxsize = 3
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]),
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]),
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
group_name, mock_func, stars = 'whatever', MagicMock(), 3
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
# We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then
# acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3,
# at the end of the loop, we expect it be locked.
self.assertTrue(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
call(mock_func, arg2, arg_stars=stars),
call(mock_func, bad, arg_stars=stars)
])
@patch.object(pool, 'create_task')

View File

@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Working example of a UnixControlServer in combination with the SimpleTaskPool.
Working example of a TCPControlServer in combination with the SimpleTaskPool.
Use the main CLI client to interface at the socket.
"""