Compare commits

..

7 Commits

25 changed files with 508 additions and 153 deletions

View File

@ -1,14 +1,13 @@
[run]
source = src/
branch = true
omit =
.venv/*
command_line = -m unittest discover
[report]
fail_under = 100
show_missing = True
skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
omit =
tests/*
if sys.version_info.+:

86
.github/workflows/main.yaml vendored Normal file
View File

@ -0,0 +1,86 @@
name: CI
on: [push]
jobs:
tests:
name: Python ${{ matrix.python-version }} Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- '3.8'
- '3.9'
- '3.10'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: 'requirements/dev.txt'
- name: Upgrade packaging tools
run: pip install -U pip
- name: Install dependencies
run: pip install -U -r requirements/dev.txt
- name: Install asyncio-taskpool
run: pip install -e .
- name: Run tests for Python ${{ matrix.python-version }}
if: ${{ matrix.python-version != '3.10' }}
run: python -m tests
- name: Run tests for Python 3.10 and save coverage
if: ${{ matrix.python-version == '3.10' }}
run: echo "coverage=$(./coverage.sh)" >> $GITHUB_ENV
outputs:
coverage: ${{ env.coverage }}
update_badges:
needs: tests
name: Update Badges
env:
meta_gist_id: 3f8240a976e8781a765d9c74a583dcda
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Download `cloc`
run: sudo apt-get update -y && sudo apt-get install -y cloc
- name: Count lines of code/comments
run: |
echo "cloc_code=$(./cloc.sh -c src/)" >> $GITHUB_ENV
echo "cloc_comments=$(./cloc.sh -m src/)" >> $GITHUB_ENV
echo "cloc_commentpercent=$(./cloc.sh -p src/)" >> $GITHUB_ENV
- name: Create badge for lines of code
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: cloc-code.json
label: Lines of Code
message: ${{ env.cloc_code }}
- name: Create badge for lines of comments
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: cloc-comments.json
label: Comments
message: ${{ env.cloc_comments }} (${{ env.cloc_commentpercent }}%)
- name: Create badge for test coverage
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: test-coverage.json
label: Coverage
message: ${{ needs.tests.outputs.coverage }}

11
.readthedocs.yaml Normal file
View File

@ -0,0 +1,11 @@
version: 2
build:
os: 'ubuntu-20.04'
tools:
python: '3.8'
python:
install:
- method: pip
path: .
sphinx:
fail_on_warning: true

View File

@ -1,7 +1,30 @@
[//]: # (This file is part of asyncio-taskpool.)
[//]: # (asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of)
[//]: # (version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.)
[//]: # (asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;)
[//]: # (without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.)
[//]: # (See the GNU Lesser General Public License for more details.)
[//]: # (You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.)
[//]: # (If not, see <https://www.gnu.org/licenses/>.)
# asyncio-taskpool
[![GitHub last commit][github-last-commit-img]][github-last-commit]
![Lines of code][gist-cloc-code-img]
![Lines of comments][gist-cloc-comments-img]
![Test coverage][gist-test-coverage-img]
[![License: LGPL v3.0][lgpl3-img]][lgpl3]
[![PyPI version][pypi-latest-version-img]][pypi-latest-version]
**Dynamically manage pools of asyncio tasks**
Full documentation available at [RtD](https://asyncio-taskpool.readthedocs.io/en/latest).
---
## Contents
- [Contents](#contents)
- [Summary](#summary)
@ -55,8 +78,7 @@ Python Version 3.8+, tested on Linux
## Testing
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
Install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`, then execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and save the coverage report.
## License
@ -67,3 +89,13 @@ The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COP
---
© 2022 Daniil Fajnberg
[github-last-commit]: https://github.com/daniil-berg/asyncio-taskpool/commits
[github-last-commit-img]: https://img.shields.io/github/last-commit/daniil-berg/asyncio-taskpool?label=Last%20commit&logo=git&
[gist-cloc-code-img]: https://img.shields.io/endpoint?logo=python&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-code.json
[gist-cloc-comments-img]: https://img.shields.io/endpoint?logo=sharp&color=lightgrey&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-comments.json
[gist-test-coverage-img]: https://img.shields.io/endpoint?logo=pytest&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/test-coverage.json
[lgpl3]: https://www.gnu.org/licenses/lgpl-3.0
[lgpl3-img]: https://img.shields.io/badge/License-LGPL_v3.0-darkgreen.svg?logo=gnu
[pypi-latest-version-img]: https://img.shields.io/pypi/v/asyncio-taskpool?color=teal&logo=pypi
[pypi-latest-version]: https://pypi.org/project/asyncio-taskpool/

46
cloc.sh Executable file
View File

@ -0,0 +1,46 @@
#!/usr/bin/env bash
# This file is part of asyncio-taskpool.
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>.
typeset option
if getopts 'bcmp' option; then
if [[ ${option} == [bcmp] ]]; then
shift
else
echo >&2 "Invalid option '$1' provided"
exit 1
fi
fi
typeset source=$1
if [[ -z ${source} ]]; then
echo >&2 Source file/directory missing
exit 1
fi
typeset blank code comment commentpercent
read blank comment code commentpercent < <( \
cloc --csv --quiet --hide-rate --include-lang Python ${source} |
awk -F, '$2 == "SUM" {printf ("%d %d %d %1.0f", $3, $4, $5, 100 * $4 / ($5 + $4)); exit}'
)
case ${option} in
b) echo ${blank} ;;
c) echo ${code} ;;
m) echo ${comment} ;;
p) echo ${commentpercent} ;;
*) echo Blank lines: ${blank}
echo Lines of comments: ${comment}
echo Lines of code: ${code}
echo Comment percentage: ${commentpercent} ;;
esac

View File

@ -1,3 +1,25 @@
#!/usr/bin/env sh
#!/usr/bin/env bash
coverage erase && coverage run -m unittest discover && coverage report
# This file is part of asyncio-taskpool.
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>.
coverage erase
coverage run 2> /dev/null
typeset report=$(coverage report)
typeset total=$(echo "${report}" | awk '$1 == "TOTAL" {print $NF; exit}')
if [[ ${total} == 100% ]]; then
echo ${total}
else
echo "${report}"
fi

View File

View File

@ -1,7 +0,0 @@
asyncio\_taskpool.control.parser module
=======================================
.. automodule:: asyncio_taskpool.control.parser
:members:
:undoc-members:
:show-inheritance:

View File

@ -13,6 +13,4 @@ Submodules
:maxdepth: 4
asyncio_taskpool.control.client
asyncio_taskpool.control.parser
asyncio_taskpool.control.server
asyncio_taskpool.control.session

View File

@ -1,7 +0,0 @@
asyncio\_taskpool.control.session module
========================================
.. automodule:: asyncio_taskpool.control.session
:members:
:undoc-members:
:show-inheritance:

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 1.0.0-beta
version = 1.0.1
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks
@ -11,7 +11,7 @@ url = https://git.fajnberg.de/daniil/asyncio-taskpool
project_urls =
Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues
classifiers =
Development Status :: 4 - Beta
Development Status :: 5 - Production/Stable
Programming Language :: Python :: 3
Operating System :: OS Independent
License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)

View File

@ -97,21 +97,22 @@ class ControlClient(ABC):
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
`None`, if either `Ctrl+C` was hit, an empty or whitespace-only string was entered, or the user wants the
client to disconnect; otherwise, returns the user's input, stripped of leading and trailing spaces and
converted to lowercase.
"""
try:
msg = input("> ").strip().lower()
cmd = input("> ").strip().lower()
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
msg = CLIENT_EXIT
cmd = CLIENT_EXIT
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print()
return
if msg == CLIENT_EXIT:
if cmd == CLIENT_EXIT:
writer.close()
self._connected = False
return
return msg
return cmd or None # will be None if `cmd` is an empty string
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
"""

View File

@ -17,19 +17,21 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Definition of the :class:`ControlParser` used in a
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
It should not be considered part of the public API.
"""
import logging
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
from ast import literal_eval
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature
from io import StringIO
from shutil import get_terminal_size
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
from ..exceptions import HelpRequested, ParserError
from ..internals.constants import CLIENT_INFO, CMD, STREAM_WRITER
from ..internals.constants import CLIENT_INFO, CMD
from ..internals.helpers import get_first_doc_line, resolve_dotted_path
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
@ -52,8 +54,8 @@ class ControlParser(ArgumentParser):
"""
Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
instance passed to it during initialization.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a file-like
`StringIO` instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client.
Finally, it offers some convenience methods and makes use of custom exceptions.
@ -87,25 +89,23 @@ class ControlParser(ArgumentParser):
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None:
def __init__(self, stream: StringIO, terminal_width: int = None, **kwargs) -> None:
"""
Sets some internal attributes in addition to the base class.
Args:
stream_writer:
The instance of the :class:`asyncio.StreamWriter` to use for message output.
stream:
A file-like I/O object to use for message output.
terminal_width (optional):
The terminal width to use for all message formatting. By default the :code:`columns` attribute from
:func:`shutil.get_terminal_size` is taken.
**kwargs(optional):
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
"""
self._stream_writer: StreamWriter = stream_writer
self._stream: StringIO = stream
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs)
self._flags: Set[str] = set()
self._commands = None
@ -194,7 +194,7 @@ class ControlParser(ArgumentParser):
Dictionary mapping class member names to the (sub-)parsers created from them.
"""
parsers: ParsersDict = {}
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
common_kwargs = {'stream': self._stream, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
for name, member in getmembers(cls):
if name in omit_members or (name.startswith('_') and public_only):
continue
@ -214,9 +214,9 @@ class ControlParser(ArgumentParser):
return self._commands
def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream buffer."""
if message:
self._stream_writer.write(message.encode())
self._stream.write(message)
def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked."""

View File

@ -31,6 +31,7 @@ from typing import Optional, Union
from .client import ControlClient, TCPControlClient, UnixControlClient
from .session import ControlSession
from ..pool import AnyTaskPoolT
from ..internals.helpers import classmethod
from ..internals.types import ConnectedCallbackT, PathT

View File

@ -16,6 +16,8 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
It should not be considered part of the public API.
"""
@ -24,12 +26,13 @@ import json
from argparse import ArgumentError
from asyncio.streams import StreamReader, StreamWriter
from inspect import isfunction, signature
from io import StringIO
from typing import Callable, Optional, Union, TYPE_CHECKING
from .parser import ControlParser
from ..exceptions import CommandError, HelpRequested, ParserError
from ..pool import TaskPool, SimpleTaskPool
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES
from ..internals.helpers import return_or_exception
if TYPE_CHECKING:
@ -72,6 +75,7 @@ class ControlSession:
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[ControlParser] = None
self._response_buffer: StringIO = StringIO()
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
"""
@ -133,7 +137,7 @@ class ControlSession:
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
parser_kwargs = {
STREAM_WRITER: self._writer,
'stream': self._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '',
'usage': f'[-h] [{CMD}] ...'
@ -160,7 +164,7 @@ class ControlSession:
kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e:
log.debug("%s got an ArgumentError", self._client_class_name)
self._writer.write(str(e).encode())
self._response_buffer.write(str(e))
return
except (HelpRequested, ParserError):
log.debug("%s received usage help", self._client_class_name)
@ -171,7 +175,7 @@ class ControlSession:
elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs)
else:
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode())
self._response_buffer.write(str(CommandError(f"Unknown command object: {command}")))
async def listen(self) -> None:
"""
@ -188,4 +192,8 @@ class ControlSession:
log.debug("%s disconnected", self._client_class_name)
break
await self._parse_command(msg)
response = self._response_buffer.getvalue()
self._response_buffer.seek(0)
self._response_buffer.truncate()
self._writer.write(response.encode())
await self._writer.drain()

View File

@ -27,7 +27,6 @@ DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100
STREAM_WRITER = 'stream_writer'
CMD = 'command'
CMD_OK = b"ok"

View File

@ -19,10 +19,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
"""
import builtins
import sys
from asyncio.coroutines import iscoroutinefunction
from importlib import import_module
from inspect import getdoc
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Type, Union
from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -131,3 +133,25 @@ def resolve_dotted_path(dotted_path: str) -> object:
import_module(module_name)
found = getattr(found, name)
return found
class ClassMethodWorkaround:
"""Dirty workaround to make the `@classmethod` decorator work with properties."""
def __init__(self, method_or_property: Union[Callable, property]) -> None:
if isinstance(method_or_property, property):
self._getter = method_or_property.fget
else:
self._getter = method_or_property
def __get__(self, obj: Union[T, None], cls: Union[Type[T], None]) -> Any:
if obj is None:
return self._getter(cls)
return self._getter(obj)
# Starting with Python 3.9, this is thankfully no longer necessary.
if sys.version_info[:2] < (3, 9):
classmethod = ClassMethodWorkaround
else:
classmethod = builtins.classmethod

View File

@ -23,7 +23,7 @@ This module should **not** be considered part of the public API.
from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T')
@ -31,8 +31,8 @@ T = TypeVar('T')
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
AnyCallableT = Callable[..., Union[T, Awaitable[T]]]
CoroutineFunc = Callable[..., Coroutine]
EndCB = Callable
CancelCB = Callable

View File

@ -326,6 +326,8 @@ class BaseTaskPool:
"""
self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
await self._enough_room.acquire()
# TODO: Make sure that cancellation (group or pool) interrupts this method after context switching!
# Possibly make use of the task group register for that.
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
async with group_reg:
task_id = self._num_started
@ -528,6 +530,8 @@ class BaseTaskPool:
self._tasks_cancelled.clear()
self._tasks_running.clear()
self._closed = True
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server.
class TaskPool(BaseTaskPool):
@ -566,36 +570,50 @@ class TaskPool(BaseTaskPool):
return name
i += 1
async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
Creates coroutines with the supplied arguments and runs them as new tasks in the pool.
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
Args:
group_name:
Name of the task group to add the new task to.
Name of the task group to add the new tasks to.
func:
The coroutine function to be run as a task within the task pool.
The coroutine function to be run in `num` tasks within the task pool.
args (optional):
The positional arguments to pass into the function call.
The positional arguments to pass into each function call.
kwargs (optional):
The keyword-arguments to pass into the function call.
The keyword-arguments to pass into each function call.
num (optional):
The number of tasks to spawn with the specified parameters.
end_callback (optional):
A callback to execute after the task has ended.
A callback to execute after each task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of the task.
A callback to execute after cancellation of each task.
It is run with the task's ID as its only positional argument.
"""
if kwargs is None:
kwargs = {}
# TODO: Add exception logging
await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback) for _ in range(num)))
for i in range(num):
try:
coroutine = func(*args, **kwargs)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue
try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
@ -622,7 +640,7 @@ class TaskPool(BaseTaskPool):
kwargs (optional):
The keyword-arguments to pass into each function call.
num (optional):
The number of tasks to spawn with the specified parameters.
The number of tasks to spawn with the specified parameters. Defaults to 1.
group_name (optional):
Name of the task group to add the new tasks to. By default, a unique name is constructed in the form
:code:`'apply-{name}-group-{idx}'` (with `name` being the name of the `func` and `idx` being an
@ -650,7 +668,7 @@ class TaskPool(BaseTaskPool):
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
self._task_groups.setdefault(group_name, TaskGroupRegister())
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num,
end_callback=end_callback, cancel_callback=cancel_callback)))
return group_name
@ -690,27 +708,31 @@ class TaskPool(BaseTaskPool):
The callback that was specified to execute after cancellation of the task (and the next one).
It is run with the task's ID as its only positional argument.
"""
map_semaphore = Semaphore(num_concurrent)
release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
for next_arg in arg_iter:
semaphore = Semaphore(num_concurrent)
release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback)
for i, next_arg in enumerate(arg_iter):
semaphore_acquired = False
try:
coroutine = star_function(func, next_arg, arg_stars=arg_stars)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)",
str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg))
continue
try:
# When the number of running tasks spawned by this method reaches the specified maximum,
# this next line will block, until one of them ends and releases the semaphore.
await map_semaphore.acquire()
try:
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
semaphore_acquired = await semaphore.acquire()
await self._start_task(coroutine, group_name=group_name, ignore_lock=True,
end_callback=release_cb, cancel_callback=cancel_callback)
except CancelledError:
# This means that no more tasks are supposed to be created from this `arg_iter`;
# thus, we can forget about the rest of the arguments.
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
map_semaphore.release()
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.)
log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i)
coroutine.close()
if semaphore_acquired:
semaphore.release()
return
except Exception as e:
# This means an exception occurred during task **creation**, meaning no task has been created.
# It does not imply an error within the task itself.
log.exception("%s occurred while trying to create task: %s(%s%s)",
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
map_semaphore.release()
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
@ -924,6 +946,7 @@ class SimpleTaskPool(BaseTaskPool):
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num)
)
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
await gather(*start_coroutines)
def start(self, num: int) -> str:

30
tests/__main__.py Normal file
View File

@ -0,0 +1,30 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Main entry point for all unit tests.
"""
import sys
import unittest
if __name__ == '__main__':
test_suite = unittest.defaultTestLoader.discover('.')
test_runner = unittest.TextTestRunner(resultclass=unittest.TextTestResult)
result = test_runner.run(test_suite)
sys.exit(not result.wasSuccessful())

View File

@ -38,7 +38,7 @@ class CLITestCase(IsolatedAsyncioTestCase):
mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls, **mock_client_kwargs}
self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs)

View File

@ -41,9 +41,9 @@ class ControlParserTestCase(TestCase):
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.stream_writer, self.terminal_width = MagicMock(), 420
self.stream, self.terminal_width = MagicMock(), 420
self.kwargs = {
'stream_writer': self.stream_writer,
'stream': self.stream,
'terminal_width': self.terminal_width,
'formatter_class': FOO
}
@ -72,10 +72,9 @@ class ControlParserTestCase(TestCase):
def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.stream_writer, self.parser._stream_writer)
self.assertEqual(self.stream, self.parser._stream)
self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
self.assertSetEqual(set(), self.parser._flags)
self.assertIsNone(self.parser._commands)
@ -89,7 +88,7 @@ class ControlParserTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'foo-bar'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
to_omit = ['abc', 'xyz']
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
self.assertEqual(mock_subparser, output)
@ -107,7 +106,7 @@ class ControlParserTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'get-prop'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_called_once_with(get_prop)
@ -119,7 +118,7 @@ class ControlParserTestCase(TestCase):
prop = property(get_prop, set_prop)
expected_help = f"Get/set the `.{expected_name}` property"
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
@ -152,8 +151,7 @@ class ControlParserTestCase(TestCase):
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
x = 'x'
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer,
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
common_kwargs = {'stream': self.parser._stream, parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
self.assertDictEqual(expected_output, output)
@ -170,12 +168,12 @@ class ControlParserTestCase(TestCase):
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
def test__print_message(self):
self.stream_writer.write = MagicMock()
self.stream.write = MagicMock()
self.assertIsNone(self.parser._print_message(''))
self.stream_writer.write.assert_not_called()
self.stream.write.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg))
self.stream_writer.write.assert_called_once_with(msg.encode())
self.stream.write.assert_called_once_with(msg)
@patch.object(parser.ControlParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock):

View File

@ -21,11 +21,12 @@ Unittests for the `asyncio_taskpool.session` module.
import json
from argparse import ArgumentError, Namespace
from io import StringIO
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES
from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool
@ -61,14 +62,15 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser)
self.assertIsInstance(self.session._response_buffer, StringIO)
@patch.object(session, 'return_or_exception')
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
def method(self, arg1, arg2, *var_args, **rest): pass
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args}
mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs))
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs, **test_rest))
mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
)
@ -104,7 +106,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = {
STREAM_WRITER: self.mock_writer,
'stream': self.session._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '',
'usage': f'[-h] [{CMD}] ...'
@ -132,10 +134,9 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
mock__exec_property_and_respond.assert_not_called()
@ -145,7 +146,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
@ -161,26 +162,28 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.assert_called_once_with(msg.split(' '))
mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_not_called()
self.mock_writer.write.assert_called_once_with(str(exc).encode())
self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock()
self.mock_writer.write.reset_mock()
self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode())
self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock()
self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited()
@ -191,17 +194,23 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
response = FOO + BAR + FOO
self.session._response_buffer.write(response)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg)
self.assertEqual('', self.session._response_buffer.getvalue())
self.mock_writer.write.assert_called_once_with(response.encode())
self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock()
self.mock_writer.write.reset_mock()
self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited()
self.mock_writer.write.assert_not_called()
self.mock_writer.drain.assert_not_awaited()

View File

@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.helpers` module.
"""
from unittest import IsolatedAsyncioTestCase
from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import helpers
@ -122,3 +122,33 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(AttributeError):
helpers.resolve_dotted_path('foo.bar.baz')
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
class ClassMethodWorkaroundTestCase(TestCase):
def test_init(self):
def func(): return 'foo'
def getter(): return 'bar'
prop = property(getter)
instance = helpers.ClassMethodWorkaround(func)
self.assertIs(func, instance._getter)
instance = helpers.ClassMethodWorkaround(prop)
self.assertIs(getter, instance._getter)
@patch.object(helpers.ClassMethodWorkaround, '__init__', return_value=None)
def test_get(self, _mock_init: MagicMock):
def func(x: MagicMock): return x.__name__
instance = helpers.ClassMethodWorkaround(MagicMock())
instance._getter = func
obj, cls = None, MagicMock
expected_output = 'MagicMock'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
obj = MagicMock(__name__='bar')
expected_output = 'bar'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
cls = None
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)

View File

@ -455,48 +455,64 @@ class TaskPoolTestCase(CommonTestCase):
self.assertEqual(expected_output, output)
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock):
group_name = FOO + BAR
mock_awaitable = object()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
async def test__apply_spawner(self, mock__start_task: AsyncMock):
grp_name = FOO + BAR
mock_awaitable1, mock_awaitable2 = object(), object()
mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
mock__start_task.assert_has_awaits(3 * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args, **kw)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
])
mock_func.reset_mock()
mock_func.reset_mock(side_effect=True)
mock__start_task.reset_mock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args)])
mock__start_task.assert_has_awaits(num * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(2 * [call(*args)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
@patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start')
def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object()
mock__apply_spawner.return_value = mock_apply_coroutine = object()
mock_create_task.return_value = fake_task = object()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._task_groups = {group_name: 'causes error'}
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=mock_func)
mock__apply_spawner.assert_not_called()
mock_create_task.assert_not_called()
mock__check_start.reset_mock()
self.task_pool._task_groups = {}
def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine)
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
@ -507,7 +523,7 @@ class TaskPoolTestCase(CommonTestCase):
mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
mock__apply_num.reset_mock()
mock__apply_spawner.reset_mock()
mock_create_task.reset_mock()
output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
@ -532,20 +548,20 @@ class TaskPoolTestCase(CommonTestCase):
n = 2
mock_semaphore_cls.return_value = semaphore = Semaphore(n)
mock__get_map_end_callback.return_value = map_cb = MagicMock()
awaitable = 'totally an awaitable'
mock_star_function.side_effect = [awaitable, Exception(), awaitable]
awaitable1, awaitable2 = 'totally an awaitable', object()
mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
arg1, arg2, bad = 123456789, 'function argument', None
args = [arg1, bad, arg2]
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then
# acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked.
self.assertTrue(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
mock__start_task.assert_has_awaits([
call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
])
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
@ -556,17 +572,50 @@ class TaskPoolTestCase(CommonTestCase):
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting a task:
mock_semaphore_cls.return_value = semaphore = Semaphore(1)
mock_star_function.side_effect = CancelledError()
self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
self.assertFalse(semaphore.locked())
# With a CancelledError thrown while acquiring the semaphore:
mock_acquire = AsyncMock(side_effect=[True, CancelledError])
mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
mock_star_function.return_value = mock_coroutine = MagicMock()
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_not_called()
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
end_callback=map_cb, cancel_callback=cancel_cb)
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_not_called()
self.assertEqual(FOO, next(arg_it))
mock_acquire.reset_mock(side_effect=True)
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting the task:
mock__start_task.side_effect = [None, CancelledError]
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_has_awaits(2 * [
call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_called_once_with()
self.assertEqual(FOO, next(arg_it))
@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
@ -680,13 +729,15 @@ class SimpleTaskPoolTestCase(CommonTestCase):
TEST_POOL_CANCEL_CB = MagicMock()
def get_task_pool_init_params(self) -> dict:
return super().get_task_pool_init_params() | {
params = super().get_task_pool_init_params()
params.update({
'func': self.TEST_POOL_FUNC,
'args': self.TEST_POOL_ARGS,
'kwargs': self.TEST_POOL_KWARGS,
'end_callback': self.TEST_POOL_END_CB,
'cancel_callback': self.TEST_POOL_CANCEL_CB,
}
})
return params
def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
@ -695,6 +746,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def tearDown(self) -> None:
self.base_class_init_patcher.stop()
super().tearDown()
def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)