Compare commits

..

13 Commits

30 changed files with 558 additions and 130 deletions

View File

@ -1,14 +1,12 @@
[run] [run]
source = src/ source = src/
branch = true branch = true
omit = command_line = -m unittest discover
.venv/*
[report] [report]
fail_under = 100
show_missing = True show_missing = True
skip_covered = False skip_covered = False
exclude_lines = exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
if __name__ == ['"]__main__['"]: if __name__ == ['"]__main__['"]:
omit =
tests/*

88
.github/workflows/main.yaml vendored Normal file
View File

@ -0,0 +1,88 @@
name: CI
on:
push:
branches: [master]
jobs:
tests:
name: Python ${{ matrix.python-version }} Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- '3.8'
- '3.9'
- '3.10'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: 'requirements/dev.txt'
- name: Upgrade packaging tools
run: pip install -U pip
- name: Install dependencies
run: pip install -U -r requirements/dev.txt
- name: Install asyncio-taskpool
run: pip install -e .
- name: Run tests for Python ${{ matrix.python-version }}
if: ${{ matrix.python-version != '3.10' }}
run: python -m tests
- name: Run tests for Python 3.10 and save coverage
if: ${{ matrix.python-version == '3.10' }}
run: echo "coverage=$(./coverage.sh)" >> $GITHUB_ENV
outputs:
coverage: ${{ env.coverage }}
update_badges:
needs: tests
name: Update Badges
env:
meta_gist_id: 3f8240a976e8781a765d9c74a583dcda
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Download `cloc`
run: sudo apt-get update -y && sudo apt-get install -y cloc
- name: Count lines of code/comments
run: |
echo "cloc_code=$(./cloc.sh -c src/)" >> $GITHUB_ENV
echo "cloc_comments=$(./cloc.sh -m src/)" >> $GITHUB_ENV
echo "cloc_commentpercent=$(./cloc.sh -p src/)" >> $GITHUB_ENV
- name: Create badge for lines of code
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: cloc-code.json
label: Lines of Code
message: ${{ env.cloc_code }}
- name: Create badge for lines of comments
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: cloc-comments.json
label: Comments
message: ${{ env.cloc_comments }} (${{ env.cloc_commentpercent }}%)
- name: Create badge for test coverage
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: test-coverage.json
label: Coverage
message: ${{ needs.tests.outputs.coverage }}

11
.readthedocs.yaml Normal file
View File

@ -0,0 +1,11 @@
version: 2
build:
os: 'ubuntu-20.04'
tools:
python: '3.8'
python:
install:
- method: pip
path: .
sphinx:
fail_on_warning: true

View File

@ -1,7 +1,30 @@
[//]: # (This file is part of asyncio-taskpool.)
[//]: # (asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of)
[//]: # (version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.)
[//]: # (asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;)
[//]: # (without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.)
[//]: # (See the GNU Lesser General Public License for more details.)
[//]: # (You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.)
[//]: # (If not, see <https://www.gnu.org/licenses/>.)
# asyncio-taskpool # asyncio-taskpool
[![GitHub last commit][github-last-commit-img]][github-last-commit]
![Lines of code][gist-cloc-code-img]
![Lines of comments][gist-cloc-comments-img]
![Test coverage][gist-test-coverage-img]
[![License: LGPL v3.0][lgpl3-img]][lgpl3]
[![PyPI version][pypi-latest-version-img]][pypi-latest-version]
**Dynamically manage pools of asyncio tasks** **Dynamically manage pools of asyncio tasks**
Full documentation available at [RtD](https://asyncio-taskpool.readthedocs.io/en/latest).
---
## Contents ## Contents
- [Contents](#contents) - [Contents](#contents)
- [Summary](#summary) - [Summary](#summary)
@ -55,8 +78,7 @@ Python Version 3.8+, tested on Linux
## Testing ## Testing
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`. Install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`, then execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and save the coverage report.
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
## License ## License
@ -67,3 +89,13 @@ The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COP
--- ---
© 2022 Daniil Fajnberg © 2022 Daniil Fajnberg
[github-last-commit]: https://github.com/daniil-berg/asyncio-taskpool/commits
[github-last-commit-img]: https://img.shields.io/github/last-commit/daniil-berg/asyncio-taskpool?label=Last%20commit&logo=git&
[gist-cloc-code-img]: https://img.shields.io/endpoint?logo=python&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-code.json
[gist-cloc-comments-img]: https://img.shields.io/endpoint?logo=sharp&color=lightgrey&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-comments.json
[gist-test-coverage-img]: https://img.shields.io/endpoint?logo=pytest&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/test-coverage.json
[lgpl3]: https://www.gnu.org/licenses/lgpl-3.0
[lgpl3-img]: https://img.shields.io/badge/License-LGPL_v3.0-darkgreen.svg?logo=gnu
[pypi-latest-version-img]: https://img.shields.io/pypi/v/asyncio-taskpool?color=teal&logo=pypi
[pypi-latest-version]: https://pypi.org/project/asyncio-taskpool/

46
cloc.sh Executable file
View File

@ -0,0 +1,46 @@
#!/usr/bin/env bash
# This file is part of asyncio-taskpool.
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>.
typeset option
if getopts 'bcmp' option; then
if [[ ${option} == [bcmp] ]]; then
shift
else
echo >&2 "Invalid option '$1' provided"
exit 1
fi
fi
typeset source=$1
if [[ -z ${source} ]]; then
echo >&2 Source file/directory missing
exit 1
fi
typeset blank code comment commentpercent
read blank comment code commentpercent < <( \
cloc --csv --quiet --hide-rate --include-lang Python ${source} |
awk -F, '$2 == "SUM" {printf ("%d %d %d %1.0f", $3, $4, $5, 100 * $4 / ($5 + $4)); exit}'
)
case ${option} in
b) echo ${blank} ;;
c) echo ${code} ;;
m) echo ${comment} ;;
p) echo ${commentpercent} ;;
*) echo Blank lines: ${blank}
echo Lines of comments: ${comment}
echo Lines of code: ${code}
echo Comment percentage: ${commentpercent} ;;
esac

View File

@ -1,3 +1,25 @@
#!/usr/bin/env sh #!/usr/bin/env bash
coverage erase && coverage run -m unittest discover && coverage report # This file is part of asyncio-taskpool.
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>.
coverage erase
coverage run 2> /dev/null
typeset report=$(coverage report)
typeset total=$(echo "${report}" | awk '$1 == "TOTAL" {print $NF; exit}')
if [[ ${total} == 100% ]]; then
echo ${total}
else
echo "${report}"
fi

View File

View File

@ -1,7 +0,0 @@
asyncio\_taskpool.control.parser module
=======================================
.. automodule:: asyncio_taskpool.control.parser
:members:
:undoc-members:
:show-inheritance:

View File

@ -13,6 +13,4 @@ Submodules
:maxdepth: 4 :maxdepth: 4
asyncio_taskpool.control.client asyncio_taskpool.control.client
asyncio_taskpool.control.parser
asyncio_taskpool.control.server asyncio_taskpool.control.server
asyncio_taskpool.control.session

View File

@ -1,7 +0,0 @@
asyncio\_taskpool.control.session module
========================================
.. automodule:: asyncio_taskpool.control.session
:members:
:undoc-members:
:show-inheritance:

View File

@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
author = 'Daniil Fajnberg' author = 'Daniil Fajnberg'
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = '1.0.0-beta' release = '1.1.1'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -45,6 +45,7 @@ Contents
:maxdepth: 2 :maxdepth: 2
pages/pool pages/pool
pages/ids
pages/control pages/control
api/api api/api

42
docs/source/pages/ids.rst Normal file
View File

@ -0,0 +1,42 @@
.. This file is part of asyncio-taskpool.
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.
.. Copyright © 2022 Daniil Fajnberg
IDs, groups & names
===================
Task IDs
--------
Every task spawned within a pool receives an ID, which is an integer greater or equal to 0 that is unique **within that task pool instance**. An internal counter is incremented whenever a new task is spawned. A task with ID :code:`n` was the :code:`(n+1)`-th task to be spawned in the pool. Task IDs can be used to cancel specific tasks using the :py:meth:`.cancel() <asyncio_taskpool.pool.BaseTaskPool.cancel>` method.
In practice, it should rarely be necessary to target *specific* tasks. When dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, you would typically cancel entire task groups (see below) rather than individual tasks, whereas with :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instances you would indiscriminately cancel a number of tasks using the :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` method.
The ID of a pool task also appears in the task's name, which is set upon spawning it. (See `here <https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.set_name>`_ for the associated method of the :code:`Task` class.)
Task groups
-----------
Every method of spawning new tasks in a task pool will add them to a **task group** and return the name of that group. With :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` methods such as :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>`, the group name can be set explicitly via the :code:`group_name` parameter. By default, the name will be a string containing some meta information depending on which method is used. Passing an existing task group name in any of those methods will result in a :py:class:`InvalidGroupName <asyncio_taskpool.exceptions.InvalidGroupName>` error.
You can cancel entire task groups using the :py:meth:`.cancel_group() <asyncio_taskpool.pool.BaseTaskPool.cancel_group>` method by passing it the group name. To check which tasks belong to a group, the :py:meth:`.get_group_ids() <asyncio_taskpool.pool.BaseTaskPool.get_group_ids>` method can be used, which takes group names and returns the IDs of the tasks belonging to them.
The :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method will create a new group as well, each time it is called, but it does not allow customizing the group name. Typically, it will not be necessary to keep track of groups in a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instance.
Task groups do not impose limits on the number of tasks in them, although they can be indirectly constrained by pool size limits.
Pool names
----------
When initializing a task pool, you can provide a custom name for it, which will appear in its string representation, e.g. when using it in a :code:`print()`. A class attribute keeps track of initialized task pools and assigns each one an index (similar to IDs for pool tasks). If no name is specified when creating a new pool, its index is used in the string representation of it. Pool names can be helpful when using multiple pools and analyzing log messages.

View File

@ -87,7 +87,7 @@ By contrast, here is how you would do it with a task pool:
... ...
await pool.flush() await pool.flush()
Pretty much self-explanatory, no? Pretty much self-explanatory, no? (See :doc:`here <./ids>` for more information about groups/names).
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look: Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 1.0.0 version = 1.1.1
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -35,7 +35,7 @@ __all__ = []
CLIENT_CLASS = 'client_class' CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp' UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path' SOCKET_PATH = 'socket_path'
HOST, PORT = 'host', 'port' HOST, PORT = 'host', 'port'

View File

@ -97,21 +97,22 @@ class ControlClient(ABC):
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns: Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect; `None`, if either `Ctrl+C` was hit, an empty or whitespace-only string was entered, or the user wants the
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase. client to disconnect; otherwise, returns the user's input, stripped of leading and trailing spaces and
converted to lowercase.
""" """
try: try:
msg = input("> ").strip().lower() cmd = input("> ").strip().lower()
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command. except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
msg = CLIENT_EXIT cmd = CLIENT_EXIT
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt. except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print() print()
return return
if msg == CLIENT_EXIT: if cmd == CLIENT_EXIT:
writer.close() writer.close()
self._connected = False self._connected = False
return return
return msg return cmd or None # will be None if `cmd` is an empty string
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
""" """

View File

@ -17,19 +17,21 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
Definition of the :class:`ControlParser` used in a Definition of the :class:`ControlParser` used in a
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`. :class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
It should not be considered part of the public API.
""" """
import logging import logging
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
from ast import literal_eval from ast import literal_eval
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature from inspect import Parameter, getmembers, isfunction, signature
from io import StringIO
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
from ..exceptions import HelpRequested, ParserError from ..exceptions import HelpRequested, ParserError
from ..internals.constants import CLIENT_INFO, CMD, STREAM_WRITER from ..internals.constants import CLIENT_INFO, CMD
from ..internals.helpers import get_first_doc_line, resolve_dotted_path from ..internals.helpers import get_first_doc_line, resolve_dotted_path
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
@ -52,8 +54,8 @@ class ControlParser(ArgumentParser):
""" """
Subclass of the standard :code:`argparse.ArgumentParser` for pool control. Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a file-like
instance passed to it during initialization. `StringIO` instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client. connected client.
Finally, it offers some convenience methods and makes use of custom exceptions. Finally, it offers some convenience methods and makes use of custom exceptions.
@ -87,25 +89,23 @@ class ControlParser(ArgumentParser):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
return ClientHelpFormatter return ClientHelpFormatter
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None: def __init__(self, stream: StringIO, terminal_width: int = None, **kwargs) -> None:
""" """
Sets some internal attributes in addition to the base class. Sets some internal attributes in addition to the base class.
Args: Args:
stream_writer: stream:
The instance of the :class:`asyncio.StreamWriter` to use for message output. A file-like I/O object to use for message output.
terminal_width (optional): terminal_width (optional):
The terminal width to use for all message formatting. By default the :code:`columns` attribute from The terminal width to use for all message formatting. By default the :code:`columns` attribute from
:func:`shutil.get_terminal_size` is taken. :func:`shutil.get_terminal_size` is taken.
**kwargs(optional): **kwargs(optional):
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`. class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
""" """
self._stream_writer: StreamWriter = stream_writer self._stream: StringIO = stream
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class')) kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs) super().__init__(**kwargs)
self._flags: Set[str] = set() self._flags: Set[str] = set()
self._commands = None self._commands = None
@ -194,7 +194,7 @@ class ControlParser(ArgumentParser):
Dictionary mapping class member names to the (sub-)parsers created from them. Dictionary mapping class member names to the (sub-)parsers created from them.
""" """
parsers: ParsersDict = {} parsers: ParsersDict = {}
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width} common_kwargs = {'stream': self._stream, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
for name, member in getmembers(cls): for name, member in getmembers(cls):
if name in omit_members or (name.startswith('_') and public_only): if name in omit_members or (name.startswith('_') and public_only):
continue continue
@ -214,9 +214,9 @@ class ControlParser(ArgumentParser):
return self._commands return self._commands
def _print_message(self, message: str, *args, **kwargs) -> None: def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream buffer."""
if message: if message:
self._stream_writer.write(message.encode()) self._stream.write(message)
def exit(self, status: int = 0, message: str = None) -> None: def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked.""" """This is overridden to prevent system exit to be invoked."""

View File

@ -31,6 +31,7 @@ from typing import Optional, Union
from .client import ControlClient, TCPControlClient, UnixControlClient from .client import ControlClient, TCPControlClient, UnixControlClient
from .session import ControlSession from .session import ControlSession
from ..pool import AnyTaskPoolT from ..pool import AnyTaskPoolT
from ..internals.helpers import classmethod
from ..internals.types import ConnectedCallbackT, PathT from ..internals.types import ConnectedCallbackT, PathT

View File

@ -16,6 +16,8 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
Definition of the :class:`ControlSession` used by a :class:`ControlServer`. Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
It should not be considered part of the public API.
""" """
@ -24,12 +26,13 @@ import json
from argparse import ArgumentError from argparse import ArgumentError
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from inspect import isfunction, signature from inspect import isfunction, signature
from io import StringIO
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from .parser import ControlParser from .parser import ControlParser
from ..exceptions import CommandError, HelpRequested, ParserError from ..exceptions import CommandError, HelpRequested, ParserError
from ..pool import TaskPool, SimpleTaskPool from ..pool import TaskPool, SimpleTaskPool
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES
from ..internals.helpers import return_or_exception from ..internals.helpers import return_or_exception
if TYPE_CHECKING: if TYPE_CHECKING:
@ -72,6 +75,7 @@ class ControlSession:
self._reader: StreamReader = reader self._reader: StreamReader = reader
self._writer: StreamWriter = writer self._writer: StreamWriter = writer
self._parser: Optional[ControlParser] = None self._parser: Optional[ControlParser] = None
self._response_buffer: StringIO = StringIO()
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None: async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
""" """
@ -133,7 +137,7 @@ class ControlSession:
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
parser_kwargs = { parser_kwargs = {
STREAM_WRITER: self._writer, 'stream': self._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH], CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '', 'prog': '',
'usage': f'[-h] [{CMD}] ...' 'usage': f'[-h] [{CMD}] ...'
@ -160,7 +164,7 @@ class ControlSession:
kwargs = vars(self._parser.parse_args(msg.split(' '))) kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e: except ArgumentError as e:
log.debug("%s got an ArgumentError", self._client_class_name) log.debug("%s got an ArgumentError", self._client_class_name)
self._writer.write(str(e).encode()) self._response_buffer.write(str(e))
return return
except (HelpRequested, ParserError): except (HelpRequested, ParserError):
log.debug("%s received usage help", self._client_class_name) log.debug("%s received usage help", self._client_class_name)
@ -171,7 +175,7 @@ class ControlSession:
elif isinstance(command, property): elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs) await self._exec_property_and_respond(command, **kwargs)
else: else:
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode()) self._response_buffer.write(str(CommandError(f"Unknown command object: {command}")))
async def listen(self) -> None: async def listen(self) -> None:
""" """
@ -188,4 +192,8 @@ class ControlSession:
log.debug("%s disconnected", self._client_class_name) log.debug("%s disconnected", self._client_class_name)
break break
await self._parse_command(msg) await self._parse_command(msg)
response = self._response_buffer.getvalue()
self._response_buffer.seek(0)
self._response_buffer.truncate()
self._writer.write(response.encode())
await self._writer.drain() await self._writer.drain()

View File

@ -21,13 +21,17 @@ This module should **not** be considered part of the public API.
""" """
import sys
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
DEFAULT_TASK_GROUP = 'default' DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100
STREAM_WRITER = 'stream_writer'
CMD = 'command' CMD = 'command'
CMD_OK = b"ok" CMD_OK = b"ok"

View File

@ -19,11 +19,13 @@ Miscellaneous helper functions. None of these should be considered part of the p
""" """
import builtins
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from importlib import import_module from importlib import import_module
from inspect import getdoc from inspect import getdoc
from typing import Any, Optional, Union from typing import Any, Callable, Optional, Type, Union
from .constants import PYTHON_BEFORE_39
from .types import T, AnyCallableT, ArgsT, KwArgsT from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -131,3 +133,25 @@ def resolve_dotted_path(dotted_path: str) -> object:
import_module(module_name) import_module(module_name)
found = getattr(found, name) found = getattr(found, name)
return found return found
class ClassMethodWorkaround:
"""Dirty workaround to make the `@classmethod` decorator work with properties."""
def __init__(self, method_or_property: Union[Callable, property]) -> None:
if isinstance(method_or_property, property):
self._getter = method_or_property.fget
else:
self._getter = method_or_property
def __get__(self, obj: Union[T, None], cls: Union[Type[T], None]) -> Any:
if obj is None:
return self._getter(cls)
return self._getter(obj)
# Starting with Python 3.9, this is thankfully no longer necessary.
if PYTHON_BEFORE_39:
classmethod = ClassMethodWorkaround
else:
classmethod = builtins.classmethod

View File

@ -31,8 +31,8 @@ T = TypeVar('T')
ArgsT = Iterable[Any] ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] AnyCallableT = Callable[..., Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Coroutine] CoroutineFunc = Callable[..., Coroutine]
EndCB = Callable EndCB = Callable
CancelCB = Callable CancelCB = Callable

View File

@ -28,16 +28,17 @@ For further details about the classes check their respective documentation.
import logging import logging
import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task, gather from asyncio.tasks import Task, create_task, gather
from contextlib import suppress from contextlib import suppress
from math import inf from math import inf
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
from . import exceptions from . import exceptions
from .internals.constants import DEFAULT_TASK_GROUP from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
from .internals.group_register import TaskGroupRegister from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -71,7 +72,7 @@ class BaseTaskPool:
# Initialize flags; immutably set the name. # Initialize flags; immutably set the name.
self._locked: bool = False self._locked: bool = False
self._closed: bool = False self._closed: Event = Event()
self._name: str = name self._name: str = name
# The following three dictionaries are the actual containers of the tasks controlled by the pool. # The following three dictionaries are the actual containers of the tasks controlled by the pool.
@ -220,7 +221,7 @@ class BaseTaskPool:
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if function and not iscoroutinefunction(function): if function and not iscoroutinefunction(function):
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}") raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
if self._closed: if self._closed.is_set():
raise exceptions.PoolIsClosed("You must use another pool") raise exceptions.PoolIsClosed("You must use another pool")
if self._locked and not ignore_lock: if self._locked and not ignore_lock:
raise exceptions.PoolIsLocked("Cannot start new tasks") raise exceptions.PoolIsLocked("Cannot start new tasks")
@ -360,6 +361,23 @@ class BaseTaskPool:
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running") raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
@staticmethod
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
"""
Returns a dictionary to unpack in a `Task.cancel()` method.
This method exists to ensure proper compatibility with older Python versions.
If `msg` is `None`, an empty dictionary is returned.
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
Otherwise the keyword dictionary contains the `msg` parameter.
"""
if msg is None:
return {}
if PYTHON_BEFORE_39:
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
return {}
return {'msg': msg}
def cancel(self, *task_ids: int, msg: str = None) -> None: def cancel(self, *task_ids: int, msg: str = None) -> None:
""" """
Cancels the tasks with the specified IDs. Cancels the tasks with the specified IDs.
@ -378,8 +396,9 @@ class BaseTaskPool:
`InvalidTaskID`: One of the `task_ids` is not known to the pool. `InvalidTaskID`: One of the `task_ids` is not known to the pool.
""" """
tasks = [self._get_running_task(task_id) for task_id in task_ids] tasks = [self._get_running_task(task_id) for task_id in task_ids]
kw = self._get_cancel_kw(msg)
for task in tasks: for task in tasks:
task.cancel(msg=msg) task.cancel(**kw)
def _cancel_group_meta_tasks(self, group_name: str) -> None: def _cancel_group_meta_tasks(self, group_name: str) -> None:
"""Cancels and forgets all meta tasks associated with the task group named `group_name`.""" """Cancels and forgets all meta tasks associated with the task group named `group_name`."""
@ -392,7 +411,7 @@ class BaseTaskPool:
self._meta_tasks_cancelled.update(meta_tasks) self._meta_tasks_cancelled.update(meta_tasks)
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name) log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None: def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
""" """
Removes all tasks from the specified group and cancels them. Removes all tasks from the specified group and cancels them.
@ -406,7 +425,7 @@ class BaseTaskPool:
self._cancel_group_meta_tasks(group_name) self._cancel_group_meta_tasks(group_name)
while group_reg: while group_reg:
try: try:
self._tasks_running[group_reg.pop()].cancel(msg=msg) self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
except KeyError: except KeyError:
continue continue
log.debug("%s cancelled tasks from group %s", str(self), group_name) log.debug("%s cancelled tasks from group %s", str(self), group_name)
@ -433,7 +452,8 @@ class BaseTaskPool:
group_reg = self._task_groups.pop(group_name) group_reg = self._task_groups.pop(group_name)
except KeyError: except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) kw = self._get_cancel_kw(msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
log.debug("%s forgot task group %s", str(self), group_name) log.debug("%s forgot task group %s", str(self), group_name)
def cancel_all(self, msg: str = None) -> None: def cancel_all(self, msg: str = None) -> None:
@ -448,9 +468,10 @@ class BaseTaskPool:
msg (optional): Passed to the `Task.cancel()` method of every task. msg (optional): Passed to the `Task.cancel()` method of every task.
""" """
log.warning("%s cancelling all tasks!", str(self)) log.warning("%s cancelling all tasks!", str(self))
kw = self._get_cancel_kw(msg)
while self._task_groups: while self._task_groups:
group_name, group_reg = self._task_groups.popitem() group_name, group_reg = self._task_groups.popitem()
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
def _pop_ended_meta_tasks(self) -> Set[Task]: def _pop_ended_meta_tasks(self) -> Set[Task]:
""" """
@ -529,9 +550,16 @@ class BaseTaskPool:
self._tasks_ended.clear() self._tasks_ended.clear()
self._tasks_cancelled.clear() self._tasks_cancelled.clear()
self._tasks_running.clear() self._tasks_running.clear()
self._closed = True self._closed.set()
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server. async def until_closed(self) -> bool:
"""
Waits until the pool has been closed. (This method itself does **not** close the pool, but blocks until then.)
Returns:
`True` once the pool is closed.
"""
return await self._closed.wait()
class TaskPool(BaseTaskPool): class TaskPool(BaseTaskPool):
@ -604,7 +632,7 @@ class TaskPool(BaseTaskPool):
# This means there was probably something wrong with the function arguments. # This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)", log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs)) str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue continue # TODO: Consider returning instead of continuing
try: try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback, await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback) cancel_callback=cancel_callback)
@ -640,7 +668,7 @@ class TaskPool(BaseTaskPool):
kwargs (optional): kwargs (optional):
The keyword-arguments to pass into each function call. The keyword-arguments to pass into each function call.
num (optional): num (optional):
The number of tasks to spawn with the specified parameters. The number of tasks to spawn with the specified parameters. Defaults to 1.
group_name (optional): group_name (optional):
Name of the task group to add the new tasks to. By default, a unique name is constructed in the form Name of the task group to add the new tasks to. By default, a unique name is constructed in the form
:code:`'apply-{name}-group-{idx}'` (with `name` being the name of the `func` and `idx` being an :code:`'apply-{name}-group-{idx}'` (with `name` being the name of the `func` and `idx` being an
@ -941,13 +969,24 @@ class SimpleTaskPool(BaseTaskPool):
async def _start_num(self, num: int, group_name: str) -> None: async def _start_num(self, num: int, group_name: str) -> None:
"""Starts `num` new tasks in group `group_name`.""" """Starts `num` new tasks in group `group_name`."""
start_coroutines = ( for i in range(num):
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name, try:
end_callback=self._end_callback, cancel_callback=self._cancel_callback) coroutine = self._func(*self._args, **self._kwargs)
for _ in range(num) except Exception as e:
) # This means there was probably something wrong with the function arguments.
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling! log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
await gather(*start_coroutines) str(e.__class__.__name__), str(self), self._func.__name__,
repr(self._args), repr(self._kwargs))
continue # TODO: Consider returning instead of continuing
try:
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
cancel_callback=self._cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
def start(self, num: int) -> str: def start(self, num: int) -> str:
""" """

30
tests/__main__.py Normal file
View File

@ -0,0 +1,30 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Main entry point for all unit tests.
"""
import sys
import unittest
if __name__ == '__main__':
test_suite = unittest.defaultTestLoader.discover('.')
test_runner = unittest.TextTestRunner(resultclass=unittest.TextTestResult)
result = test_runner.run(test_suite)
sys.exit(not result.wasSuccessful())

View File

@ -38,7 +38,7 @@ class CLITestCase(IsolatedAsyncioTestCase):
mock_client = MagicMock(start=mock_client_start) mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client) mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789} mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls, **mock_client_kwargs}
self.assertIsNone(await module.main()) self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with() mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs) mock_client_cls.assert_called_once_with(**mock_client_kwargs)

View File

@ -41,9 +41,9 @@ class ControlParserTestCase(TestCase):
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory') self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.stream_writer, self.terminal_width = MagicMock(), 420 self.stream, self.terminal_width = MagicMock(), 420
self.kwargs = { self.kwargs = {
'stream_writer': self.stream_writer, 'stream': self.stream,
'terminal_width': self.terminal_width, 'terminal_width': self.terminal_width,
'formatter_class': FOO 'formatter_class': FOO
} }
@ -72,10 +72,9 @@ class ControlParserTestCase(TestCase):
def test_init(self): def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser) self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.stream_writer, self.parser._stream_writer) self.assertEqual(self.stream, self.parser._stream)
self.assertEqual(self.terminal_width, self.parser._terminal_width) self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
self.assertSetEqual(set(), self.parser._flags) self.assertSetEqual(set(), self.parser._flags)
self.assertIsNone(self.parser._commands) self.assertIsNone(self.parser._commands)
@ -89,7 +88,7 @@ class ControlParserTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123' mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'foo-bar' expected_name = 'foo-bar'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
to_omit = ['abc', 'xyz'] to_omit = ['abc', 'xyz']
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs) output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
@ -107,7 +106,7 @@ class ControlParserTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123' mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'get-prop' expected_name = 'get-prop'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs) output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_called_once_with(get_prop) mock_get_first_doc_line.assert_called_once_with(get_prop)
@ -119,7 +118,7 @@ class ControlParserTestCase(TestCase):
prop = property(get_prop, set_prop) prop = property(get_prop, set_prop)
expected_help = f"Get/set the `.{expected_name}` property" expected_help = f"Get/set the `.{expected_name}` property"
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs) output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)]) mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
@ -152,8 +151,7 @@ class ControlParserTestCase(TestCase):
mock_subparser = MagicMock(set_defaults=mock_set_defaults) mock_subparser = MagicMock(set_defaults=mock_set_defaults)
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
x = 'x' x = 'x'
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer, common_kwargs = {'stream': self.parser._stream, parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
expected_output = {'method': mock_subparser, 'prop': mock_subparser} expected_output = {'method': mock_subparser, 'prop': mock_subparser}
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x) output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
@ -170,12 +168,12 @@ class ControlParserTestCase(TestCase):
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs) mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
def test__print_message(self): def test__print_message(self):
self.stream_writer.write = MagicMock() self.stream.write = MagicMock()
self.assertIsNone(self.parser._print_message('')) self.assertIsNone(self.parser._print_message(''))
self.stream_writer.write.assert_not_called() self.stream.write.assert_not_called()
msg = 'foo bar baz' msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg)) self.assertIsNone(self.parser._print_message(msg))
self.stream_writer.write.assert_called_once_with(msg.encode()) self.stream.write.assert_called_once_with(msg)
@patch.object(parser.ControlParser, '_print_message') @patch.object(parser.ControlParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock): def test_exit(self, mock__print_message: MagicMock):

View File

@ -21,11 +21,12 @@ Unittests for the `asyncio_taskpool.session` module.
import json import json
from argparse import ArgumentError, Namespace from argparse import ArgumentError, Namespace
from io import StringIO
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session from asyncio_taskpool.control import session
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES
from asyncio_taskpool.exceptions import HelpRequested from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool from asyncio_taskpool.pool import SimpleTaskPool
@ -61,14 +62,15 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.assertEqual(self.mock_reader, self.session._reader) self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer) self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser) self.assertIsNone(self.session._parser)
self.assertIsInstance(self.session._response_buffer, StringIO)
@patch.object(session, 'return_or_exception') @patch.object(session, 'return_or_exception')
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock): async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
def method(self, arg1, arg2, *var_args, **rest): pass def method(self, arg1, arg2, *var_args, **rest): pass
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11} test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args}
mock_return_or_exception.return_value = None mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs)) self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs, **test_rest))
mock_return_or_exception.assert_awaited_once_with( mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
) )
@ -104,7 +106,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_reader.read = mock_read self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock() self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = { expected_parser_kwargs = {
STREAM_WRITER: self.mock_writer, 'stream': self.session._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: width, CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '', 'prog': '',
'usage': f'[-h] [{CMD}] ...' 'usage': f'[-h] [{CMD}] ...'
@ -132,10 +134,9 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
kwargs = {FOO: BAR, 'hello': 'python'} kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs)) mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args) self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs) mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
mock__exec_property_and_respond.assert_not_called() mock__exec_property_and_respond.assert_not_called()
@ -145,7 +146,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs) mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_called() mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs) mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
@ -161,26 +162,28 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
mock__exec_method_and_respond.assert_not_called() mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_not_called() mock__exec_property_and_respond.assert_not_called()
self.mock_writer.write.assert_called_once_with(str(exc).encode()) self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_property_and_respond.reset_mock() mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock() mock_parse_args.reset_mock()
self.mock_writer.write.reset_mock() self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode()) self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited() mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited() mock__exec_property_and_respond.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock() mock_parse_args.reset_mock()
self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = HelpRequested() mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited() mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited() mock__exec_property_and_respond.assert_not_awaited()
@ -191,17 +194,23 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating" msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
response = FOO + BAR + FOO
self.session._response_buffer.write(response)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg) mock__parse_command.assert_awaited_once_with(msg)
self.assertEqual('', self.session._response_buffer.getvalue())
self.mock_writer.write.assert_called_once_with(response.encode())
self.mock_writer.drain.assert_awaited_once_with() self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock() self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock() mock__parse_command.reset_mock()
self.mock_writer.write.reset_mock()
self.mock_writer.drain.reset_mock() self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False) self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited() self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited() mock__parse_command.assert_not_awaited()
self.mock_writer.write.assert_not_called()
self.mock_writer.drain.assert_not_awaited() self.mock_writer.drain.assert_not_awaited()

View File

@ -18,10 +18,11 @@ __doc__ = """
Unittests for the `asyncio_taskpool.helpers` module. Unittests for the `asyncio_taskpool.helpers` module.
""" """
import importlib
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import constants
from asyncio_taskpool.internals import helpers from asyncio_taskpool.internals import helpers
@ -122,3 +123,45 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
helpers.resolve_dotted_path('foo.bar.baz') helpers.resolve_dotted_path('foo.bar.baz')
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')]) mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
class ClassMethodWorkaroundTestCase(TestCase):
def test_init(self):
def func(): return 'foo'
def getter(): return 'bar'
prop = property(getter)
instance = helpers.ClassMethodWorkaround(func)
self.assertIs(func, instance._getter)
instance = helpers.ClassMethodWorkaround(prop)
self.assertIs(getter, instance._getter)
@patch.object(helpers.ClassMethodWorkaround, '__init__', return_value=None)
def test_get(self, _mock_init: MagicMock):
def func(x: MagicMock): return x.__name__
instance = helpers.ClassMethodWorkaround(MagicMock())
instance._getter = func
obj, cls = None, MagicMock
expected_output = 'MagicMock'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
obj = MagicMock(__name__='bar')
expected_output = 'bar'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
cls = None
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
def test_correct_class(self):
is_older_python = constants.PYTHON_BEFORE_39
try:
constants.PYTHON_BEFORE_39 = True
importlib.reload(helpers)
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
constants.PYTHON_BEFORE_39 = False
importlib.reload(helpers)
self.assertIs(classmethod, helpers.classmethod)
finally:
constants.PYTHON_BEFORE_39 = is_older_python

View File

@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
""" """
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type from typing import Type
@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(0, self.task_pool._num_started) self.assertEqual(0, self.task_pool._num_started)
self.assertFalse(self.task_pool._locked) self.assertFalse(self.task_pool._locked)
self.assertFalse(self.task_pool._closed) self.assertIsInstance(self.task_pool._closed, Event)
self.assertFalse(self.task_pool._closed.is_set())
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool.get_group_ids(group_name, 'something else') self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self): async def test__check_start(self):
self.task_pool._closed = True self.task_pool._closed.set()
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try: try:
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._check_start(awaitable=None, function=mock_coroutine) self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed): with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None) self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False self.task_pool._closed.clear()
self.task_pool._locked = True self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked): with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
@ -311,13 +312,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called() mock__task_name.assert_not_called()
@patch('warnings.warn')
def test__get_cancel_kw(self, mock_warn: MagicMock):
msg = None
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
msg = 'something'
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_called_once()
mock_warn.reset_mock()
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_get_running_task') @patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock): def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
task_id1, task_id2, task_id3 = 1, 4, 9 task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock() mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO)) self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) mock__get_cancel_kw.assert_called_once_with(FOO)
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
def test__cancel_group_meta_tasks(self): def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock() mock_task1, mock_task2 = MagicMock(), MagicMock()
@ -336,6 +356,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks') @patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock): def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
kw = {BAR: 10, BAZ: 20}
task_id = 555 task_id = 555
mock_cancel = MagicMock() mock_cancel = MagicMock()
@ -347,27 +368,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
class MockRegister(set, MagicMock): class MockRegister(set, MagicMock):
pass pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO)) self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
mock_cancel.assert_called_once_with(msg=FOO) mock_cancel.assert_called_once_with(**kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock() self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
with self.assertRaises(exceptions.InvalidGroupName): with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.cancel_group(BAR) self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called() mock__cancel_and_remove_all_from_group.assert_not_called()
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR)) self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR) mock__get_cancel_kw.assert_called_once_with(BAR)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_group_reg = MagicMock() mock_group_reg = MagicMock()
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg} self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
self.assertIsNone(self.task_pool.cancel_all('msg')) self.assertIsNone(self.task_pool.cancel_all(BAZ))
mock__get_cancel_kw.assert_called_once_with(BAZ)
mock__cancel_and_remove_all_from_group.assert_has_calls([ mock__cancel_and_remove_all_from_group.assert_has_calls([
call(BAR, mock_group_reg, msg='msg'), call(BAR, mock_group_reg, **fake_cancel_kw),
call(FOO, mock_group_reg, msg='msg') call(FOO, mock_group_reg, **fake_cancel_kw)
]) ])
def test__pop_ended_meta_tasks(self): def test__pop_ended_meta_tasks(self):
@ -435,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertTrue(self.task_pool._closed) self.assertTrue(self.task_pool._closed.is_set())
async def test_until_closed(self):
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
output = await self.task_pool.until_closed()
self.assertEqual(FOO, output)
self.task_pool._closed.wait.assert_awaited_once_with()
class TaskPoolTestCase(CommonTestCase): class TaskPoolTestCase(CommonTestCase):
@ -729,13 +762,15 @@ class SimpleTaskPoolTestCase(CommonTestCase):
TEST_POOL_CANCEL_CB = MagicMock() TEST_POOL_CANCEL_CB = MagicMock()
def get_task_pool_init_params(self) -> dict: def get_task_pool_init_params(self) -> dict:
return super().get_task_pool_init_params() | { params = super().get_task_pool_init_params()
params.update({
'func': self.TEST_POOL_FUNC, 'func': self.TEST_POOL_FUNC,
'args': self.TEST_POOL_ARGS, 'args': self.TEST_POOL_ARGS,
'kwargs': self.TEST_POOL_KWARGS, 'kwargs': self.TEST_POOL_KWARGS,
'end_callback': self.TEST_POOL_END_CB, 'end_callback': self.TEST_POOL_END_CB,
'cancel_callback': self.TEST_POOL_CANCEL_CB, 'cancel_callback': self.TEST_POOL_CANCEL_CB,
} })
return params
def setUp(self) -> None: def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__') self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
@ -762,18 +797,30 @@ class SimpleTaskPoolTestCase(CommonTestCase):
@patch.object(pool.SimpleTaskPool, '_start_task') @patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_num(self, mock__start_task: AsyncMock): async def test__start_num(self, mock__start_task: AsyncMock):
fake_coroutine = object()
self.task_pool._func = MagicMock(return_value=fake_coroutine)
num = 3
group_name = FOO + BAR + 'abc' group_name = FOO + BAR + 'abc'
mock_awaitable1, mock_awaitable2 = object(), object()
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
num = 3
self.assertIsNone(await self.task_pool._start_num(num, group_name)) self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(num * [ self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
call(*self.task_pool._args, **self.task_pool._kwargs) call_kw = {
]) 'group_name': group_name,
mock__start_task.assert_has_awaits(num * [ 'end_callback': self.task_pool._end_callback,
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback, 'cancel_callback': self.task_pool._cancel_callback
cancel_callback=self.task_pool._cancel_callback) }
]) mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
self.task_pool._func.reset_mock(side_effect=True)
mock__start_task.reset_mock()
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock()) @patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())