Compare commits

..

27 Commits

Author SHA1 Message Date
3503c0bf44 removed a few functions from the public API; fixed some docstrings/comments 2022-03-14 19:16:28 +01:00
3d104c979e typo 2022-03-14 18:13:51 +01:00
a92e646411 improved and extended usage/readme docs; small fixes 2022-03-14 18:09:30 +01:00
3d84e1552b version bump 2022-03-13 16:12:04 +01:00
38f4ec1b06 finally reached 100% unittest coverage overall 2022-03-13 16:11:20 +01:00
6f082288d8 brought unittest coverage up to 100% on control modules 2022-03-13 15:44:53 +01:00
9fde231250 moved control-related modules to a sub-package; minor corrections 2022-03-13 15:18:53 +01:00
c72a5035ea big rework of the session-parser-interaction;
dynamically adding pool methods/properties as parser commands;
dynamically executing selected pool method/property;
greatly simplified `ControlSession` class;
removed the need for hard-coded command names;
adjusted unittests accordingly
2022-03-13 14:56:56 +01:00
eb152e4d75 fixed unix server/client tests 2022-03-08 10:22:07 +01:00
d05f84b2c3 additional base commands for control server 2022-03-08 10:15:10 +01:00
7c66604ad0 renamed method get_task_group_ids and extended to accept any number of group names 2022-03-08 09:08:28 +01:00
287906a218 added unittests 2022-03-08 09:05:59 +01:00
ce0f9a1f65 version bump 2022-02-25 22:44:25 +01:00
5dad4ab0c7 implemented TCP socket control; switched example to TCP 2022-02-25 22:42:37 +01:00
ae6bb1bd17 skipping unix socket tests on Windows 2022-02-25 21:17:14 +01:00
e501a849f3 clarifications and corrections 2022-02-25 19:57:54 +01:00
ed6badb088 moved imports for unix socket connections to init methods of client and server 2022-02-25 19:09:28 +01:00
c63f079da4 huge update:
introduced meta tasks which are used by `_map()`;
introduced task groups;
ending with `gather_and_close()` now;
pool unittests rewritten accordingly;
two new helper classes
2022-02-24 19:16:24 +01:00
4994135062 renamed num_tasks in the map-methods to group_size; reworded/extended the docstrings 2022-02-19 16:02:50 +01:00
d0c0177681 full unit test coverage and docstrings for client module; minor refactoring 2022-02-19 12:56:08 +01:00
bac7b32342 full unit test coverage and docstrings for session_parser module; minor changes 2022-02-14 22:11:31 +01:00
96d01e7259 full unit test coverage and docstrings for session module 2022-02-14 17:59:11 +01:00
3f3eb7ce38 minor rephrasing 2022-02-13 21:12:35 +01:00
05d51eface simplified method 2022-02-13 20:58:36 +01:00
b6aed727e9 additional unit tests 2022-02-13 19:55:27 +01:00
c9a8d9ecd1 better placeholders 2022-02-13 19:45:53 +01:00
538b9cc91c major refactoring of the control session/parser classes; restructured constants 2022-02-13 19:39:21 +01:00
32 changed files with 3134 additions and 1173 deletions

View File

@ -5,8 +5,10 @@ omit =
.venv/*
[report]
fail_under = 100
show_missing = True
skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
omit =
tests/*

View File

@ -2,9 +2,18 @@
**Dynamically manage pools of asyncio tasks**
## Contents
- [Contents](#contents)
- [Summary](#summary)
- [Usage](#usage)
- [Installation](#installation)
- [Dependencies](#dependencies)
- [Testing](#testing)
- [License](#license)
## Summary
A task pool is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
A **task pool** is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
With an interface that is intentionally similar to the [`multiprocessing.Pool`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool) class from the standard library, the `TaskPool` provides you such methods as `apply`, `map`, and `starmap` to execute coroutines concurrently as [`asyncio.Task`](https://docs.python.org/3/library/asyncio-task.html#task-object) objects. There is no limitation imposed on what kind of tasks can be run or in what combination, when new ones can be added, or when they can be cancelled.
@ -14,12 +23,20 @@ If you need control over a task pool at runtime, you can launch an asynchronous
## Usage
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like:
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
```python
from asyncio_taskpool import SimpleTaskPool
...
async def work(foo, bar): ...
async def work(_foo, _bar): ...
...
async def main():
pool = SimpleTaskPool(work, args=('xyz', 420))
await pool.start(5)
@ -27,11 +44,11 @@ async def main():
pool.stop(3)
...
pool.lock()
await pool.gather()
await pool.gather_and_close()
...
```
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md).
@ -47,7 +64,7 @@ Python Version 3.8+, tested on Linux
## Testing
Install `asyncio-taskpool[dev]` dependencies or just manually install `coverage` with `pip`.
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
## License
@ -56,6 +73,6 @@ Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests an
The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COPYING.LESSER) are included in this repository. If not, see https://www.gnu.org/licenses/.
## Copyright
---
© 2022 Daniil Fajnberg

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.3.1
version = 0.6.4
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -19,5 +19,5 @@ Brings the main classes up to package level for import convenience.
"""
from .control.server import TCPControlServer, UnixControlServer
from .pool import TaskPool, SimpleTaskPool
from .server import UnixControlServer

View File

@ -1,67 +0,0 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
CLI client entry point.
"""
import sys
from argparse import ArgumentParser
from asyncio import run
from pathlib import Path
from typing import Dict, Any
from .client import ControlClient, UnixControlClient
from .constants import PACKAGE_NAME
from .pool import TaskPool
from .server import ControlServer
CONN_TYPE = 'conn_type'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
def parse_cli() -> Dict[str, Any]:
parser = ArgumentParser(
prog=PACKAGE_NAME,
description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
)
return vars(parser.parse_args())
async def main():
kwargs = parse_cli()
if kwargs[CONN_TYPE] == UNIX:
client = UnixControlClient(path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP:
# TODO: Implement the TCP client class
client = UnixControlClient(path=kwargs[SOCKET_PATH])
else:
print("Invalid connection type", file=sys.stderr)
sys.exit(2)
await client.start()
if __name__ == '__main__':
run(main())

View File

@ -1,95 +0,0 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Classes of control clients for a simply interface to a task pool control server.
"""
import json
import shutil
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from asyncio_taskpool import constants
from asyncio_taskpool.types import ClientConnT
class ControlClient(ABC):
@abstractmethod
async def open_connection(self, **kwargs) -> ClientConnT:
raise NotImplementedError
@staticmethod
def client_info() -> dict:
return {'width': shutil.get_terminal_size().columns}
def __init__(self, **conn_kwargs) -> None:
self._conn_kwargs = conn_kwargs
self._connected: bool = False
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
self._connected = True
writer.write(json.dumps(self.client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
try:
msg = input("> ").strip().lower()
except EOFError:
msg = constants.CLIENT_EXIT
except KeyboardInterrupt:
print()
return
if msg == constants.CLIENT_EXIT:
writer.close()
self._connected = False
return
try:
writer.write(msg.encode())
await writer.drain()
except ConnectionError as e:
self._connected = False
print(e, file=sys.stderr)
return
print((await reader.read(constants.MSG_BYTES)).decode())
async def start(self):
reader, writer = await self.open_connection(**self._conn_kwargs)
if reader is None:
print("Failed to connect.", file=sys.stderr)
return
await self._server_handshake(reader, writer)
while self._connected:
await self._interact(reader, writer)
print("Disconnected from control server.")
class UnixControlClient(ControlClient):
def __init__(self, **conn_kwargs) -> None:
self._socket_path = Path(conn_kwargs.pop('path'))
super().__init__(**conn_kwargs)
async def open_connection(self, **kwargs) -> ClientConnT:
try:
return await open_unix_connection(self._socket_path, **kwargs)
except FileNotFoundError:
print("No socket at", self._socket_path, file=sys.stderr)
return None, None

View File

@ -20,13 +20,19 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool'
MSG_BYTES = 1024000
CMD = 'command'
CMD_NAME = 'name'
CMD_POOL_SIZE = 'pool-size'
CMD_NUM_RUNNING = 'num-running'
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop-all'
CMD_FUNC_NAME = 'func-name'
DEFAULT_TASK_GROUP = ''
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100
STREAM_WRITER = 'stream_writer'
CMD = 'command'
CMD_OK = b"ok"
class CLIENT_INFO:
__slots__ = ()
TERMINAL_WIDTH = 'terminal_width'

View File

View File

@ -0,0 +1,77 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
CLI client entry point.
"""
from argparse import ArgumentParser
from asyncio import run
from pathlib import Path
from typing import Any, Dict, Sequence
from ..constants import PACKAGE_NAME
from ..pool import TaskPool
from .client import ControlClient, TCPControlClient, UnixControlClient
from .server import TCPControlServer, UnixControlServer
CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
HOST, PORT = 'host', 'port'
def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]:
parser = ArgumentParser(
prog=f'{PACKAGE_NAME}.control',
description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types")
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
tcp_parser.add_argument(
HOST,
help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
tcp_parser.add_argument(
PORT,
type=int,
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient})
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
)
unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient})
return vars(parser.parse_args(args))
async def main():
kwargs = parse_cli()
client_cls = kwargs.pop(CLIENT_CLASS)
await client_cls(**kwargs).start()
if __name__ == '__main__':
run(main())

View File

@ -0,0 +1,192 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Classes of control clients for a simply interface to a task pool control server.
"""
import json
import shutil
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_connection
from pathlib import Path
from typing import Optional, Union
from ..constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from ..types import ClientConnT, PathT
class ControlClient(ABC):
"""
Abstract base class for a simple implementation of a task pool control client.
Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
This is a minimal working implementation.
"""
@staticmethod
def _client_info() -> dict:
"""Returns a dictionary of client information relevant for the handshake with the server."""
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
@abstractmethod
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs`
(unpacked) as keyword-arguments.
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
`None` and `None`, if it failed to establish the defined connection.
"""
raise NotImplementedError
def __init__(self, **conn_kwargs) -> None:
"""Simply stores the connection keyword-arguments necessary for opening the connection."""
self._conn_kwargs = conn_kwargs
self._connected: bool = False
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
Performs the first interaction with the server providing it with the necessary client information.
Upon completion, the server's info is printed.
Args:
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
self._connected = True
writer.write(json.dumps(self._client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
print("Type '-h' to get help and usage instructions for all available commands.\n")
def _get_command(self, writer: StreamWriter) -> Optional[str]:
"""
Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.
Args:
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
"""
try:
msg = input("> ").strip().lower()
except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
msg = CLIENT_EXIT
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print()
return
if msg == CLIENT_EXIT:
writer.close()
self._connected = False
return
return msg
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
Reacts to the user's command, potentially performing a back-and-forth interaction with the server.
If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.
Args:
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
cmd = self._get_command(writer)
if cmd is None:
return
try:
# Send the command to the server.
writer.write(cmd.encode())
await writer.drain()
except ConnectionError as e:
self._connected = False
print(e, file=sys.stderr)
return
# Await the server's response, then print it.
print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self) -> None:
"""
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.
If the connection can not be established, an error message is printed to `stderr` and the method returns.
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
disconnected-message.
"""
reader, writer = await self._open_connection(**self._conn_kwargs)
if reader is None:
print("Failed to connect.", file=sys.stderr)
return
await self._server_handshake(reader, writer)
while self._connected:
await self._interact(reader, writer)
print("Disconnected from control server.")
class TCPControlClient(ControlClient):
"""Task pool control client that expects a TCP socket to be exposed by the control server."""
def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None:
"""In addition to what the base class does, `host` and `port` are expected as non-optional arguments."""
self._host = host
self._port = port
super().__init__(**conn_kwargs)
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Wrapper around the `asyncio.open_connection` function.
Returns a tuple of `None` and `None`, if the connection can not be established;
otherwise, the stream-reader and -writer tuple is returned.
"""
try:
return await open_connection(self._host, self._port, **kwargs)
except ConnectionError as e:
print(str(e), file=sys.stderr)
return None, None
class UnixControlClient(ControlClient):
"""Task pool control client that expects a unix socket to be exposed by the control server."""
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
"""In addition to what the base class does, the `socket_path` is expected as a non-optional argument."""
from asyncio.streams import open_unix_connection
self._open_unix_connection = open_unix_connection
self._socket_path = Path(socket_path)
super().__init__(**conn_kwargs)
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Wrapper around the `asyncio.open_unix_connection` function.
Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
otherwise, the stream-reader and -writer tuple is returned.
"""
try:
return await self._open_unix_connection(self._socket_path, **kwargs)
except FileNotFoundError:
print("No socket at", self._socket_path, file=sys.stderr)
return None, None

View File

@ -0,0 +1,302 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the `ControlParser` class used by a control server.
"""
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature
from shutil import get_terminal_size
from typing import Any, Callable, Container, Dict, Set, Type, TypeVar
from ..constants import CLIENT_INFO, CMD, STREAM_WRITER
from ..exceptions import HelpRequested, ParserError
from ..helpers import get_first_doc_line
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
ParsersDict = Dict[str, 'ControlParser']
OMIT_PARAMS_DEFAULT = ('self', )
NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description'
class ControlParser(ArgumentParser):
"""
Subclass of the standard `argparse.ArgumentParser` for remote interaction.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client.
Finally, it offers some convenience methods and makes use of custom exceptions.
"""
@staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
"""
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
as convenient, when making use of sub-parsers.
Args:
terminal_width:
The number of columns of the terminal to which to adjust help formatting.
base_cls (optional):
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.
Returns:
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
"""
if base_cls is None:
base_cls = ArgumentDefaultsHelpFormatter
class ClientHelpFormatter(base_cls):
def __init__(self, *args, **kwargs) -> None:
kwargs['width'] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None,
**kwargs) -> None:
"""
Subclass of the `ArgumentParser` geared towards asynchronous interaction with an object "from the outside".
Allows directing output to a specified writer rather than stdout/stderr and setting terminal width explicitly.
Args:
stream_writer:
The instance of the `asyncio.StreamWriter` to use for message output.
terminal_width (optional):
The terminal width to use for all message formatting. Defaults to `shutil.get_terminal_size().columns`.
**kwargs(optional):
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
class is specified, it will always be subclassed in the `help_formatter_factory`.
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
"""
self._stream_writer: StreamWriter = stream_writer
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs)
self._flags: Set[str] = set()
self._commands = None
def add_function_command(self, function: Callable, omit_params: Container[str] = OMIT_PARAMS_DEFAULT,
**subparser_kwargs) -> 'ControlParser':
"""
Takes a function along with its parameters and adds a corresponding (sub-)command to the parser.
The `add_subparsers` method must have been called prior to this.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any public method of the `SimpleTaskPool` class.
Args:
function:
The reference to the function to be "converted" to a parser command.
omit_params (optional):
Names of function parameters not to add as parser arguments.
**subparser_kwargs (optional):
Passed directly to the `add_parser` method.
Returns:
The subparser instance created from the function.
"""
subparser_kwargs.setdefault(NAME, function.__name__.replace('_', '-'))
subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME])
subparser_kwargs.setdefault(HELP, get_first_doc_line(function))
subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP])
subparser: ControlParser = self._commands.add_parser(**subparser_kwargs)
subparser.add_function_args(function, omit_params)
return subparser
def add_property_command(self, prop: property, cls_name: str = '', **subparser_kwargs) -> 'ControlParser':
"""
Same as the `add_function_command` method, but for properties.
Args:
prop:
The reference to the property to be "converted" to a parser command.
cls_name (optional):
Name of the class the property is defined on to appear in the command help text.
**subparser_kwargs (optional):
Passed directly to the `add_parser` method.
Returns:
The subparser instance created from the property.
"""
subparser_kwargs.setdefault(NAME, prop.fget.__name__.replace('_', '-'))
subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME])
getter_help = get_first_doc_line(prop.fget)
if prop.fset is None:
subparser_kwargs.setdefault(HELP, getter_help)
else:
subparser_kwargs.setdefault(HELP, f"Get/set the `{cls_name}.{subparser_kwargs[NAME]}` property")
subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP])
subparser: ControlParser = self._commands.add_parser(**subparser_kwargs)
if prop.fset is not None:
_, param = signature(prop.fset).parameters.values()
setter_arg_help = f"If provided: {get_first_doc_line(prop.fset)} If omitted: {getter_help}"
subparser.add_function_arg(param, nargs='?', default=SUPPRESS, help=setter_arg_help)
return subparser
def add_class_commands(self, cls: Type, public_only: bool = True, omit_members: Container[str] = (),
member_arg_name: str = CMD) -> ParsersDict:
"""
Takes a class and adds its methods and properties as (sub-)commands to the parser.
The `add_subparsers` method must have been called prior to this.
NOTE: Currently, only a limited spectrum of function parameters can be accurately converted to parser arguments.
This method works correctly with the `SimpleTaskPool` class.
Args:
cls:
The reference to the class whose methods/properties are to be "converted" to parser commands.
public_only (optional):
If `False`, protected and private members are considered as well. `True` by default.
omit_members (optional):
Names of functions/properties not to add as parser commands.
member_arg_name (optional):
After parsing the arguments, depending on which command was invoked by the user, the corresponding
method/property will be stored as an extra argument in the parsed namespace under this attribute name.
Defaults to `constants.CMD`.
Returns:
Dictionary mapping class member names to the (sub-)parsers created from them.
"""
parsers: ParsersDict = {}
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
for name, member in getmembers(cls):
if name in omit_members or (name.startswith('_') and public_only):
continue
if isfunction(member):
subparser = self.add_function_command(member, **common_kwargs)
elif isinstance(member, property):
subparser = self.add_property_command(member, cls.__name__, **common_kwargs)
else:
continue
subparser.set_defaults(**{member_arg_name: member})
parsers[name] = subparser
return parsers
def add_subparsers(self, *args, **kwargs):
"""Adds the subparsers action as an internal attribute before returning it."""
self._commands = super().add_subparsers(*args, **kwargs)
return self._commands
def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
if message:
self._stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked."""
if message:
self._print_message(message)
def error(self, message: str) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().error(message=message)
raise ParserError
def print_help(self, file=None) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().print_help(file)
raise HelpRequested
def add_function_arg(self, parameter: Parameter, **kwargs) -> Action:
"""
Takes an `inspect.Parameter` of a function and adds a corresponding argument to the parser.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any parameter of any public method of the `SimpleTaskPool` class.
Args:
parameter: The `inspect.Parameter` object to be converted to a parser argument.
**kwargs: Passed to the `add_argument` method of the base class.
Returns:
The `argparse.Action` returned by the `add_argument` method.
"""
if parameter.default is Parameter.empty:
# A non-optional function parameter should correspond to a positional argument.
name_or_flags = [parameter.name]
else:
flag = None
long = f'--{parameter.name.replace("_", "-")}'
# We try to generate a short version (flag) for the argument.
letter = parameter.name[0]
if letter not in self._flags:
flag = f'-{letter}'
self._flags.add(letter)
elif letter.upper() not in self._flags:
flag = f'-{letter.upper()}'
self._flags.add(letter.upper())
name_or_flags = [long] if flag is None else [flag, long]
if parameter.annotation is bool:
# If we are dealing with a boolean parameter, always use the 'store_true' action.
# Even if the parameter's default value is `True`, this will make the parser argument's default `False`.
kwargs.setdefault('action', 'store_true')
else:
# For now, any other type annotation will implicitly use the default action 'store'.
# In addition, we always set the default value.
kwargs.setdefault('default', parameter.default)
if parameter.kind == Parameter.VAR_POSITIONAL:
# This is to be able to later unpack an arbitrary number of positional arguments.
kwargs.setdefault('nargs', '*')
if not kwargs.get('action') == 'store_true':
# Set the type from the parameter annotation.
kwargs.setdefault('type', _get_arg_type_wrapper(parameter.annotation))
return self.add_argument(*name_or_flags, **kwargs)
def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None:
"""
Takes a function reference and adds its parameters as arguments to the parser.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any public method of the `SimpleTaskPool` class.
Args:
function:
The function whose parameters are to be converted to parser arguments.
Its parameters must be properly annotated.
omit (optional):
Names of function parameters not to add as parser arguments.
"""
for param in signature(function).parameters.values():
if param.name not in omit:
# TODO: Look into parsing docstrings properly to try and extract argument help text.
# For now, the argument help just shows the type it will be converted to.
self.add_function_arg(param, help=repr(param.annotation))
def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
"""
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
See: https://bugs.python.org/issue36078
"""
def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg)
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
wrapper.__name__ = cls.__name__
return wrapper

View File

@ -23,15 +23,15 @@ import logging
from abc import ABC, abstractmethod
from asyncio import AbstractServer
from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.streams import StreamReader, StreamWriter, start_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Optional, Union
from .client import ControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from ..pool import TaskPool, SimpleTaskPool
from ..types import ConnectedCallbackT
from .client import ControlClient, TCPControlClient, UnixControlClient
from .session import ControlSession
from .types import ConnectedCallbackT
log = logging.getLogger(__name__)
@ -125,6 +125,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
async def serve_forever(self) -> Task:
"""
This method actually starts the server and begins listening to client connections on the specified interface.
It should never block because the serving will be performed in a separate task.
"""
log.debug("Starting %s...", self.__class__.__name__)
@ -132,16 +133,36 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
return create_task(self._serve_forever())
class TCPControlServer(ControlServer):
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
_client_class = TCPControlClient
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
self._host = server_kwargs.pop('host')
self._port = server_kwargs.pop('port')
super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_server(client_connected_cb, self._host, self._port, **kwargs)
log.debug("Opened socket at %s:%s", self._host, self._port)
return server
def _final_callback(self) -> None:
log.debug("Closed socket at %s:%s", self._host, self._port)
class UnixControlServer(ControlServer):
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
from asyncio.streams import start_unix_server
self._start_unix_server = start_unix_server
self._socket_path = Path(server_kwargs.pop('path'))
super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
server = await self._start_unix_server(client_connected_cb, self._socket_path, **kwargs)
log.debug("Opened socket '%s'", str(self._socket_path))
return server

View File

@ -0,0 +1,185 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the `ControlSession` class used by the control server.
"""
import logging
import json
from argparse import ArgumentError
from asyncio.streams import StreamReader, StreamWriter
from inspect import isfunction, signature
from typing import Callable, Optional, Union, TYPE_CHECKING
from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from ..exceptions import CommandError, HelpRequested, ParserError
from ..helpers import return_or_exception
from ..pool import TaskPool, SimpleTaskPool
from .parser import ControlParser
if TYPE_CHECKING:
from .server import ControlServer
log = logging.getLogger(__name__)
class ControlSession:
"""
This class defines the API for controlling a task pool instance from the outside.
The commands received from a connected client are translated into method calls on the task pool instance.
A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
"""
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
"""
Instantiation should happen once a client connection to the control server has already been established.
For more convenient/efficient access, some of the server's properties are saved in separate attributes.
The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.
Args:
server:
The instance of a `ControlServer` subclass starting the session.
reader:
The `asyncio.StreamReader` created when a client connected to the server.
writer:
The `asyncio.StreamWriter` created when a client connected to the server.
"""
self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[ControlParser] = None
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
"""
Takes a pool method reference, executes it, and writes a response accordingly.
If the first parameter is named `self`, the method will be called with the `_pool` instance as its first
positional argument. If it returns nothing, the response upon successful execution will be `constants.CMD_OK`,
otherwise the response written to the stream will be its return value (as an encoded string).
Args:
prop:
The reference to the method defined on the `_pool` instance's class.
**kwargs (optional):
Must correspond to the arguments expected by the `method`.
Correctly unpacks arbitrary-length positional and keyword-arguments.
"""
log.debug("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__)
normal_pos, var_pos = [], []
for param in signature(method).parameters.values():
if param.name == 'self':
normal_pos.append(self._pool)
elif param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
normal_pos.append(kwargs.pop(param.name))
elif param.kind == param.VAR_POSITIONAL:
var_pos = kwargs.pop(param.name)
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
self._writer.write(CMD_OK if output is None else str(output).encode())
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
"""
Takes a pool property reference, executes its setter or getter, and writes a response accordingly.
The property set/get method will always be called with the `_pool` instance as its first positional argument.
Args:
prop:
The reference to the property defined on the `_pool` instance's class.
**kwargs (optional):
If not empty, the property setter is executed and the keyword arguments are passed along to it; the
response upon successful execution will be `constants.CMD_OK`. Otherwise the property getter is
executed and the response written to the stream will be its return value (as an encoded string).
"""
if kwargs:
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
await return_or_exception(prop.fset, self._pool, **kwargs)
self._writer.write(CMD_OK)
else:
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
async def client_handshake(self) -> None:
"""
This method must be invoked before starting any other client interaction.
Client info is retrieved, server info is sent back, and the `ControlParser` is initialized and configured.
"""
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
parser_kwargs = {
STREAM_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '',
'usage': f'[-h] [{CMD}] ...'
}
self._parser = ControlParser(**parser_kwargs)
self._parser.add_subparsers(title="Commands",
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
self._parser.add_class_commands(self._pool.__class__)
self._writer.write(str(self._pool).encode())
await self._writer.drain()
async def _parse_command(self, msg: str) -> None:
"""
Takes a message from the client and attempts to parse it.
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
`ControlParser`, nothing else happens. Otherwise, the appropriate `_exec...` method is called with the entire
dictionary of keyword-arguments returned by the `ControlParser` passed into it.
Args:
msg: The non-empty string read from the client stream.
"""
try:
kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e:
log.debug("%s got an ArgumentError", self._client_class_name)
self._writer.write(str(e).encode())
return
except (HelpRequested, ParserError):
log.debug("%s received usage help", self._client_class_name)
return
command = kwargs.pop(CMD)
if isfunction(command):
await self._exec_method_and_respond(command, **kwargs)
elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs)
else:
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode())
async def listen(self) -> None:
"""
Enters the main control loop that only ends if either the server or the client disconnect.
Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
This method should be called, when the client connection was established and the handshake was successful.
It will obviously block indefinitely.
"""
while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self._client_class_name)
break
await self._parse_command(msg)
await self._writer.drain()

View File

@ -23,6 +23,10 @@ class PoolException(Exception):
pass
class PoolIsClosed(PoolException):
pass
class PoolIsLocked(PoolException):
pass
@ -43,6 +47,10 @@ class InvalidTaskID(PoolException):
pass
class InvalidGroupName(PoolException):
pass
class PoolStillUnlocked(PoolException):
pass
@ -57,3 +65,11 @@ class ServerException(Exception):
class HelpRequested(ServerException):
pass
class ParserError(ServerException):
pass
class CommandError(ServerException):
pass

View File

@ -0,0 +1,75 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the definition of the `TaskGroupRegister` class.
"""
from asyncio.locks import Lock
from collections.abc import MutableSet
from typing import Iterator, Set
class TaskGroupRegister(MutableSet):
"""
This class combines the interface of a regular `set` with that of the `asyncio.Lock`.
It serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for
preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of
tasks, as well as before starting a task within the group.
"""
def __init__(self, *task_ids: int) -> None:
self._ids: Set[int] = set(task_ids)
self._lock = Lock()
def __contains__(self, task_id: int) -> bool:
"""Abstract method for the `MutableSet` base class."""
return task_id in self._ids
def __iter__(self) -> Iterator[int]:
"""Abstract method for the `MutableSet` base class."""
return iter(self._ids)
def __len__(self) -> int:
"""Abstract method for the `MutableSet` base class."""
return len(self._ids)
def add(self, task_id: int) -> None:
"""Abstract method for the `MutableSet` base class."""
self._ids.add(task_id)
def discard(self, task_id: int) -> None:
"""Abstract method for the `MutableSet` base class."""
self._ids.discard(task_id)
async def acquire(self) -> bool:
"""Wrapper around the lock's `acquire()` method."""
return await self._lock.acquire()
def release(self) -> None:
"""Wrapper around the lock's `release()` method."""
self._lock.release()
async def __aenter__(self) -> None:
"""Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
await self._lock.acquire()
return None
async def __aexit__(self, exc_type, exc, tb) -> None:
"""Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
self._lock.release()

View File

@ -15,11 +15,10 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Miscellaneous helper functions.
Miscellaneous helper functions. None of these should be considered part of the public API.
"""
import re
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from inspect import getdoc
@ -52,12 +51,8 @@ async def join_queue(q: Queue) -> None:
await q.join()
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0]
return getdoc(obj).strip().split("\n", 1)[0].strip()
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,58 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the definition of an `asyncio.Queue` subclass.
"""
from asyncio.queues import Queue as _Queue
from typing import Any
class Queue(_Queue):
"""This just adds a little syntactic sugar to the `asyncio.Queue`."""
def item_processed(self) -> None:
"""
Does exactly the same as `task_done()`.
This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing,
invites confusion, and immensely reduces readability (in the context of this library). And readability counts.
"""
self.task_done()
async def __aenter__(self) -> Any:
"""
Implements an asynchronous context manager for the queue.
Upon entering `get()` is awaited and subsequently whatever came out of the queue is returned.
It allows writing code this way:
>>> queue = Queue()
>>> ...
>>> async with queue as item:
>>> ...
"""
return await self.get()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""
Implements an asynchronous context manager for the queue.
Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want,
but in some situations it makes the code much cleaner.
"""
self.item_processed()

View File

@ -1,218 +0,0 @@
import logging
import json
from argparse import ArgumentError, ArgumentParser, HelpFormatter, Namespace
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
from . import constants
from .exceptions import HelpRequested
from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import TaskPool, SimpleTaskPool
if TYPE_CHECKING:
from .server import ControlServer
log = logging.getLogger(__name__)
NUM = 'num'
WIDTH = 'width'
class CommandParser(ArgumentParser):
@staticmethod
def help_formatter_factory(terminal_width: int) -> Type[HelpFormatter]:
class ClientHelpFormatter(HelpFormatter):
def __init__(self, *args, **kwargs) -> None:
kwargs[WIDTH] = terminal_width
super().__init__(*args, **kwargs)
return ClientHelpFormatter
def __init__(self, *args, **kwargs) -> None:
parent: CommandParser = kwargs.pop('parent', None)
self._stream_writer: StreamWriter = parent.stream_writer if parent else kwargs.pop('writer')
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(WIDTH)
kwargs.setdefault('formatter_class', self.help_formatter_factory(self._terminal_width))
kwargs.setdefault('exit_on_error', False)
super().__init__(*args, **kwargs)
@property
def stream_writer(self) -> StreamWriter:
return self._stream_writer
@property
def terminal_width(self) -> int:
return self._terminal_width
def _print_message(self, message: str, *args, **kwargs) -> None:
if message:
self.stream_writer.write(message.encode())
def exit(self, status: int = 0, message: str = None) -> None:
if message:
self._print_message(message)
def print_help(self, file=None) -> None:
super().print_help(file)
raise HelpRequested
class ControlSession:
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._parser: Optional[CommandParser] = None
def _add_base_commands(self):
subparsers = self._parser.add_subparsers(title="Commands", dest=constants.CMD)
subparsers.add_parser(
constants.CMD_NAME,
prog=constants.CMD_NAME,
help=get_first_doc_line(self._pool.__class__.__str__),
parent=self._parser,
)
subparser_pool_size = subparsers.add_parser(
constants.CMD_POOL_SIZE,
prog=constants.CMD_POOL_SIZE,
help="Get/set the maximum number of tasks in the pool",
parent=self._parser,
)
subparser_pool_size.add_argument(
NUM,
nargs='?',
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
)
subparsers.add_parser(
constants.CMD_NUM_RUNNING,
help=get_first_doc_line(self._pool.__class__.num_running.fget),
parent=self._parser,
)
return subparsers
def _add_simple_commands(self):
subparsers = self._add_base_commands()
subparser = subparsers.add_parser(
constants.CMD_START,
prog=constants.CMD_START,
help=get_first_doc_line(self._pool.__class__.start),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to start. Defaults to 1."
)
subparser = subparsers.add_parser(
constants.CMD_STOP,
prog=constants.CMD_STOP,
help=get_first_doc_line(self._pool.__class__.stop),
parent=self._parser,
)
subparser.add_argument(
NUM,
nargs='?',
type=int,
default=1,
help="Number of tasks to stop. Defaults to 1."
)
subparsers.add_parser(
constants.CMD_STOP_ALL,
prog=constants.CMD_STOP_ALL,
help=get_first_doc_line(self._pool.__class__.stop_all),
parent=self._parser,
)
subparsers.add_parser(
constants.CMD_FUNC_NAME,
prog=constants.CMD_FUNC_NAME,
help=get_first_doc_line(self._pool.__class__.func_name.fget),
parent=self._parser,
)
def _init_parser(self, client_terminal_width: int) -> None:
self._parser = CommandParser(prog='', writer=self._writer, width=client_terminal_width)
if isinstance(self._pool, TaskPool):
pass # TODO
elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands()
async def client_handshake(self) -> None:
client_info = json.loads((await self._reader.read(constants.MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[WIDTH])
self._writer.write(str(self._pool).encode())
await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None:
log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None:
num = kwargs.get(NUM)
if num is None:
log.debug("%s requests pool size", self._client_class_name)
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
else:
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, int(num))
async def _cmd_num_running(self, **_kwargs) -> None:
log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_start(self, **kwargs) -> None:
num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None:
num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None:
log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None:
log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
async def _execute_command(self, args: Namespace) -> None:
args = vars(args)
cmd: str = args.pop(constants.CMD, None)
if cmd is not None:
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
await method(**args)
async def _parse_command(self, msg: str) -> None:
try:
args, argv = self._parser.parse_known_args(msg.split(' '))
except ArgumentError as e:
self._writer.write(str(e).encode())
return
except HelpRequested:
return
if argv:
log.debug("%s sent unknown arguments: %s", self._client_class_name, msg)
self._writer.write(b"Invalid command!")
return
await self._execute_command(args)
async def listen(self) -> None:
while self._control_server.is_serving():
msg = (await self._reader.read(constants.MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self._client_class_name)
break
await self._parse_command(msg)
await self._writer.drain()

View File

@ -20,6 +20,7 @@ Custom type definitions used in various modules.
from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
@ -31,8 +32,10 @@ KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable
CancelCallbackT = Callable
EndCB = Callable
CancelCB = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
PathT = Union[Path, str]

View File

View File

@ -0,0 +1,45 @@
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient
from asyncio_taskpool.control import __main__ as module
class CLITestCase(IsolatedAsyncioTestCase):
def test_parse_cli(self):
socket_path = '/some/path/to.sock'
args = [module.UNIX, socket_path]
expected_kwargs = {
module.CLIENT_CLASS: UnixControlClient,
module.SOCKET_PATH: Path(socket_path)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
host, port = '1.2.3.4', '1234'
args = [module.TCP, host, port]
expected_kwargs = {
module.CLIENT_CLASS: TCPControlClient,
module.HOST: host,
module.PORT: int(port)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
with patch('sys.stderr'):
with self.assertRaises(SystemExit):
module.parse_cli(['invalid', 'foo', 'bar'])
@patch.object(module, 'parse_cli')
async def test_main(self, mock_parse_cli: MagicMock):
mock_client_start = AsyncMock()
mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
mock_client_start.assert_awaited_once_with()

View File

@ -0,0 +1,249 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.client` module.
"""
import json
import os
import shutil
import sys
from pathlib import Path
from unittest import IsolatedAsyncioTestCase, skipIf
from unittest.mock import AsyncMock, MagicMock, call, patch
from asyncio_taskpool.control import client
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
FOO, BAR = 'foo', 'bar'
class ControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.control.client.ControlClient.__abstractmethods__', set())
self.print_patcher = patch.object(client, 'print')
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_print = self.print_patcher.start()
self.kwargs = {FOO: 123, BAR: 456}
self.client = client.ControlClient(**self.kwargs)
self.mock_read = AsyncMock(return_value=FOO.encode())
self.mock_write, self.mock_drain = MagicMock(), AsyncMock()
self.mock_reader = MagicMock(read=self.mock_read)
self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain)
def tearDown(self) -> None:
self.abstract_patcher.stop()
self.print_patcher.stop()
def test_client_info(self):
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
client.ControlClient._client_info())
async def test_abstract(self):
with self.assertRaises(NotImplementedError):
await self.client._open_connection(**self.kwargs)
def test_init(self):
self.assertEqual(self.kwargs, self.client._conn_kwargs)
self.assertFalse(self.client._connected)
@patch.object(client.ControlClient, '_client_info')
async def test__server_handshake(self, mock__client_info: MagicMock):
mock__client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
self.assertTrue(self.client._connected)
mock__client_info.assert_called_once_with()
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_has_calls([
call("Connected to", self.mock_read.return_value.decode()),
call("Type '-h' to get help and usage instructions for all available commands.\n")
])
@patch.object(client, 'input')
def test__get_command(self, mock_input: MagicMock):
self.client._connected = True
mock_input.return_value = ' ' + FOO.upper() + ' '
mock_close = MagicMock()
mock_writer = MagicMock(close=mock_close)
output = self.client._get_command(mock_writer)
self.assertEqual(FOO, output)
mock_input.assert_called_once()
mock_close.assert_not_called()
self.assertTrue(self.client._connected)
mock_input.reset_mock()
mock_input.side_effect = KeyboardInterrupt
self.assertIsNone(self.client._get_command(mock_writer))
mock_input.assert_called_once()
mock_close.assert_not_called()
self.assertTrue(self.client._connected)
mock_input.reset_mock()
mock_input.side_effect = EOFError
self.assertIsNone(self.client._get_command(mock_writer))
mock_input.assert_called_once()
mock_close.assert_called_once()
self.assertFalse(self.client._connected)
@patch.object(client.ControlClient, '_get_command')
async def test__interact(self, mock__get_command: MagicMock):
self.client._connected = True
mock__get_command.return_value = None
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_not_called()
self.mock_drain.assert_not_awaited()
self.mock_read.assert_not_awaited()
self.mock_print.assert_not_called()
self.assertTrue(self.client._connected)
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
self.mock_drain.side_effect = err = ConnectionError()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_not_awaited()
self.mock_print.assert_called_once_with(err, file=sys.stderr)
self.assertFalse(self.client._connected)
self.client._connected = True
self.mock_write.reset_mock()
self.mock_drain.reset_mock(side_effect=True)
self.mock_print.reset_mock()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with(FOO)
self.assertTrue(self.client._connected)
@patch.object(client.ControlClient, '_interact')
@patch.object(client.ControlClient, '_server_handshake')
@patch.object(client.ControlClient, '_open_connection')
async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock,
mock__interact: AsyncMock):
mock__open_connection.return_value = None, None
self.assertIsNone(await self.client.start())
mock__open_connection.assert_awaited_once_with(**self.kwargs)
mock__server_handshake.assert_not_awaited()
mock__interact.assert_not_awaited()
self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr)
mock__open_connection.reset_mock()
self.mock_print.reset_mock()
mock__open_connection.return_value = self.mock_reader, self.mock_writer
self.assertIsNone(await self.client.start())
mock__open_connection.assert_awaited_once_with(**self.kwargs)
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
mock__interact.assert_not_awaited()
self.mock_print.assert_called_once_with("Disconnected from control server.")
mock__open_connection.reset_mock()
mock__server_handshake.reset_mock()
self.mock_print.reset_mock()
self.client._connected = True
def disconnect(*_args, **_kwargs) -> None: self.client._connected = False
mock__interact.side_effect = disconnect
self.assertIsNone(await self.client.start())
mock__open_connection.assert_awaited_once_with(**self.kwargs)
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer)
self.mock_print.assert_called_once_with("Disconnected from control server.")
class TCPControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.host, self.port = 'localhost', 12345
self.kwargs = {FOO: 123, BAR: 456}
self.client = client.TCPControlClient(host=self.host, port=self.port, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test_init(self):
self.assertEqual(self.host, self.client._host)
self.assertEqual(self.port, self.client._port)
self.mock_base_init.assert_called_once_with(**self.kwargs)
@patch.object(client, 'print')
@patch.object(client, 'open_connection')
async def test__open_connection(self, mock_open_connection: AsyncMock, mock_print: MagicMock):
mock_open_connection.return_value = expected_output = 'something'
kwargs = {'a': 1, 'b': 2}
output = await self.client._open_connection(**kwargs)
self.assertEqual(expected_output, output)
mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs)
mock_print.assert_not_called()
mock_open_connection.reset_mock()
mock_open_connection.side_effect = e = ConnectionError()
output1, output2 = await self.client._open_connection(**kwargs)
self.assertIsNone(output1)
self.assertIsNone(output2)
mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs)
mock_print.assert_called_once_with(str(e), file=sys.stderr)
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.path = '/tmp/asyncio_taskpool'
self.kwargs = {FOO: 123, BAR: 456}
self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test_init(self):
self.assertEqual(Path(self.path), self.client._socket_path)
self.mock_base_init.assert_called_once_with(**self.kwargs)
@patch.object(client, 'print')
async def test__open_connection(self, mock_print: MagicMock):
expected_output = 'something'
self.client._open_unix_connection = mock_open_unix_connection = AsyncMock(return_value=expected_output)
kwargs = {'a': 1, 'b': 2}
output = await self.client._open_connection(**kwargs)
self.assertEqual(expected_output, output)
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
mock_print.assert_not_called()
mock_open_unix_connection.reset_mock()
mock_open_unix_connection.side_effect = FileNotFoundError
output1, output2 = await self.client._open_connection(**kwargs)
self.assertIsNone(output1)
self.assertIsNone(output2)
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)

View File

@ -0,0 +1,268 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.control.parser` module.
"""
from argparse import ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, SUPPRESS
from inspect import signature
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
from asyncio_taskpool.control import parser
from asyncio_taskpool.exceptions import HelpRequested, ParserError
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(TestCase):
def setUp(self) -> None:
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.stream_writer, self.terminal_width = MagicMock(), 420
self.kwargs = {
'stream_writer': self.stream_writer,
'terminal_width': self.terminal_width,
'formatter_class': FOO
}
self.parser = parser.ControlParser(**self.kwargs)
def tearDown(self) -> None:
self.help_formatter_factory_patcher.stop()
def test_help_formatter_factory(self):
self.help_formatter_factory_patcher.stop()
class MockBaseClass(HelpFormatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
terminal_width = 123456789
cls = parser.ControlParser.help_formatter_factory(terminal_width, MockBaseClass)
self.assertTrue(issubclass(cls, MockBaseClass))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
cls = parser.ControlParser.help_formatter_factory(terminal_width)
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
instance = cls('prog')
self.assertEqual(terminal_width, getattr(instance, '_width'))
def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.stream_writer, self.parser._stream_writer)
self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
self.assertSetEqual(set(), self.parser._flags)
self.assertIsNone(self.parser._commands)
@patch.object(parser, 'get_first_doc_line')
def test_add_function_command(self, mock_get_first_doc_line: MagicMock):
def foo_bar(): pass
mock_subparser = MagicMock()
mock_add_parser = MagicMock(return_value=mock_subparser)
self.parser._commands = MagicMock(add_parser=mock_add_parser)
mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'foo-bar'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
to_omit = ['abc', 'xyz']
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
self.assertEqual(mock_subparser, output)
mock_add_parser.assert_called_once_with(**expected_kwargs)
mock_subparser.add_function_args.assert_called_once_with(foo_bar, to_omit)
@patch.object(parser, 'get_first_doc_line')
def test_add_property_command(self, mock_get_first_doc_line: MagicMock):
def get_prop(_self): pass
def set_prop(_self, _value): pass
prop = property(get_prop)
mock_subparser = MagicMock()
mock_add_parser = MagicMock(return_value=mock_subparser)
self.parser._commands = MagicMock(add_parser=mock_add_parser)
mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'get-prop'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_called_once_with(get_prop)
mock_add_parser.assert_called_once_with(**expected_kwargs)
mock_subparser.add_function_arg.assert_not_called()
mock_get_first_doc_line.reset_mock()
mock_add_parser.reset_mock()
prop = property(get_prop, set_prop)
expected_help = f"Get/set the `.{expected_name}` property"
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs
output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
mock_add_parser.assert_called_once_with(**expected_kwargs)
mock_subparser.add_function_arg.assert_called_once_with(
tuple(signature(set_prop).parameters.values())[1],
nargs='?',
default=SUPPRESS,
help=f"If provided: {mock_help} If omitted: {mock_help}"
)
@patch.object(parser.ControlParser, 'add_property_command')
@patch.object(parser.ControlParser, 'add_function_command')
def test_add_class_commands(self, mock_add_function_command: MagicMock, mock_add_property_command: MagicMock):
class FooBar:
some_attribute = None
def _protected(self, _): pass
def __private(self, _): pass
def to_omit(self, _): pass
def method(self, _): pass
@property
def prop(self): return None
mock_set_defaults = MagicMock()
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
x = 'x'
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer,
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
self.assertDictEqual(expected_output, output)
mock_add_function_command.assert_called_once_with(FooBar.method, **common_kwargs)
mock_add_property_command.assert_called_once_with(FooBar.prop, FooBar.__name__, **common_kwargs)
mock_set_defaults.assert_has_calls([call(**{x: FooBar.method}), call(**{x: FooBar.prop})])
@patch.object(parser.ArgumentParser, 'add_subparsers')
def test_add_subparsers(self, mock_base_add_subparsers: MagicMock):
args, kwargs = [1, 2, 42], {FOO: 123, BAR: 456}
mock_base_add_subparsers.return_value = mock_action = MagicMock()
output = self.parser.add_subparsers(*args, **kwargs)
self.assertEqual(mock_action, output)
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
def test__print_message(self):
self.stream_writer.write = MagicMock()
self.assertIsNone(self.parser._print_message(''))
self.stream_writer.write.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg))
self.stream_writer.write.assert_called_once_with(msg.encode())
@patch.object(parser.ControlParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock):
self.assertIsNone(self.parser.exit(123, ''))
mock__print_message.assert_not_called()
msg = 'foo bar baz'
self.assertIsNone(self.parser.exit(123, msg))
mock__print_message.assert_called_once_with(msg)
@patch.object(parser.ArgumentParser, 'error')
def test_error(self, mock_supercls_error: MagicMock):
with self.assertRaises(ParserError):
self.parser.error(FOO + BAR)
mock_supercls_error.assert_called_once_with(message=FOO + BAR)
@patch.object(parser.ArgumentParser, 'print_help')
def test_print_help(self, mock_print_help: MagicMock):
arg = MagicMock()
with self.assertRaises(HelpRequested):
self.parser.print_help(arg)
mock_print_help.assert_called_once_with(arg)
@patch.object(parser, '_get_arg_type_wrapper')
@patch.object(parser.ArgumentParser, 'add_argument')
def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_arg_type_wrapper: MagicMock):
mock_add_argument.return_value = expected_output = 'action'
mock__get_arg_type_wrapper.return_value = mock_type = 'fake'
foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex
bar_default, baz_default, boo_default = 1, 0.1, 1j
def func(foo: foo_type, *args: args_type, bar: bar_type = bar_default, baz: baz_type = baz_default,
boo: boo_type = boo_default, flag: bool = False):
return foo, args, bar, baz, boo, flag
param_foo, param_args, param_bar, param_baz, param_boo, param_flag = signature(func).parameters.values()
kwargs = {FOO + BAR: 'xyz'}
self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs))
mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(foo_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs))
mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(args_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs))
mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(bar_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs))
mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(baz_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs))
mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs)
mock__get_arg_type_wrapper.assert_called_once_with(boo_type)
mock_add_argument.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs))
mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs)
mock__get_arg_type_wrapper.assert_not_called()
@patch.object(parser.ControlParser, 'add_function_arg')
def test_add_function_args(self, mock_add_function_arg: MagicMock):
def func(foo: str, *args: int, bar: float = 0.1):
return foo, args, bar
_, param_args, param_bar = signature(func).parameters.values()
self.assertIsNone(self.parser.add_function_args(func, omit=['foo']))
mock_add_function_arg.assert_has_calls([
call(param_args, help=repr(param_args.annotation)),
call(param_bar, help=repr(param_bar.annotation)),
])
class RestTestCase(TestCase):
def test__get_arg_type_wrapper(self):
type_wrap = parser._get_arg_type_wrapper(int)
self.assertEqual('int', type_wrap.__name__)
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
self.assertEqual(13, type_wrap('13'))

View File

@ -1,11 +1,33 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.server` module.
"""
import asyncio
import logging
import os
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest import IsolatedAsyncioTestCase, skipIf
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import server
from asyncio_taskpool.client import ControlClient, UnixControlClient
from asyncio_taskpool.control import server
from asyncio_taskpool.control.client import ControlClient, TCPControlClient, UnixControlClient
FOO, BAR = 'foo', 'bar'
@ -24,7 +46,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
self.abstract_patcher = patch('asyncio_taskpool.control.server.ControlServer.__abstractmethods__', set())
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_pool = MagicMock()
self.kwargs = {FOO: 123, BAR: 456}
@ -98,6 +120,51 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_create_task.assert_called_once_with(mock_awaitable)
class TCPControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = server.log.level
server.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.mock_pool = MagicMock()
self.host, self.port = 'localhost', 12345
self.kwargs = {FOO: 123, BAR: 456}
self.server = server.TCPControlServer(pool=self.mock_pool, host=self.host, port=self.port, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test__client_class(self):
self.assertEqual(TCPControlClient, self.server._client_class)
def test_init(self):
self.assertEqual(self.host, self.server._host)
self.assertEqual(self.port, self.server._port)
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
@patch.object(server, 'start_server')
async def test__get_server_instance(self, mock_start_server: AsyncMock):
mock_start_server.return_value = expected_output = 'totally_a_server'
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
args = [mock_callback]
output = await self.server._get_server_instance(*args, **mock_kwargs)
self.assertEqual(expected_output, output)
mock_start_server.assert_called_once_with(mock_callback, self.host, self.port, **mock_kwargs)
def test__final_callback(self):
self.assertIsNone(self.server._final_callback())
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@ -128,9 +195,9 @@ class UnixControlServerTestCase(IsolatedAsyncioTestCase):
self.assertEqual(Path(self.path), self.server._socket_path)
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
@patch.object(server, 'start_unix_server')
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
async def test__get_server_instance(self):
expected_output = 'totally_a_server'
self.server._start_unix_server = mock_start_unix_server = AsyncMock(return_value=expected_output)
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
args = [mock_callback]
output = await self.server._get_server_instance(*args, **mock_kwargs)

View File

@ -0,0 +1,207 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.session` module.
"""
import json
from argparse import ArgumentError, Namespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER
from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = session.log.level
session.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
session.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
self.mock_client_class_name = FOO + BAR
self.mock_server = MagicMock(pool=self.mock_pool,
client_class_name=self.mock_client_class_name)
self.mock_reader = MagicMock()
self.mock_writer = MagicMock()
self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
def test_init(self):
self.assertEqual(self.mock_server, self.session._control_server)
self.assertEqual(self.mock_pool, self.session._pool)
self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser)
@patch.object(session, 'return_or_exception')
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
def method(self, arg1, arg2, *var_args, **rest): pass
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest
mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs))
mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
)
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
@patch.object(session, 'return_or_exception')
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
def prop_get(_): pass
def prop_set(_): pass
prop = property(prop_get, prop_set)
kwargs = {'value': 'something'}
mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
mock_return_or_exception.return_value = val = 420.69
self.assertIsNone(await self.session._exec_property_and_respond(prop))
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
self.mock_writer.write.assert_called_once_with(str(val).encode())
@patch.object(session, 'ControlParser')
async def test_client_handshake(self, mock_parser_cls: MagicMock):
mock_add_subparsers, mock_add_class_commands = MagicMock(), MagicMock()
mock_parser = MagicMock(add_subparsers=mock_add_subparsers, add_class_commands=mock_add_class_commands)
mock_parser_cls.return_value = mock_parser
width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = {
STREAM_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '',
'usage': f'[-h] [{CMD}] ...'
}
expected_subparsers_kwargs = {
'title': "Commands",
'metavar': "(A command followed by '-h' or '--help' will show command-specific help.)"
}
self.assertIsNone(await self.session.client_handshake())
self.assertEqual(mock_parser, self.session._parser)
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session.ControlSession, '_exec_property_and_respond')
@patch.object(session.ControlSession, '_exec_method_and_respond')
async def test__parse_command(self, mock__exec_method_and_respond: AsyncMock,
mock__exec_property_and_respond: AsyncMock):
def method(_): pass
prop = property(method)
msg = 'asdf asd as a'
kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
mock__exec_property_and_respond.assert_not_called()
mock__exec_method_and_respond.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock()
bad_command = 'definitely not a function or property'
mock_parse_args.return_value = Namespace(**{CMD: bad_command}, **kwargs)
with patch.object(session, 'CommandError') as cmd_err_cls:
cmd_err_cls.return_value = exc = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
cmd_err_cls.assert_called_once_with(f"Unknown command object: {bad_command}")
mock_parse_args.assert_called_once_with(msg.split(' '))
mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_not_called()
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock()
self.mock_writer.write.reset_mock()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited()
@patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty():
self.mock_reader.read.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg)
self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock()
self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited()
self.mock_writer.drain.assert_not_awaited()

View File

@ -0,0 +1,85 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.group_register` module.
"""
from asyncio.locks import Lock
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import group_register
FOO, BAR = 'foo', 'bar'
class TaskGroupRegisterTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.reg = group_register.TaskGroupRegister()
def test_init(self):
ids = [FOO, BAR, 1, 2]
reg = group_register.TaskGroupRegister(*ids)
self.assertSetEqual(set(ids), reg._ids)
self.assertIsInstance(reg._lock, Lock)
def test___contains__(self):
self.reg._ids = {1, 2, 3}
for i in self.reg._ids:
self.assertTrue(i in self.reg)
self.assertFalse(4 in self.reg)
@patch.object(group_register, 'iter', return_value=FOO)
def test___iter__(self, mock_iter: MagicMock):
self.assertEqual(FOO, self.reg.__iter__())
mock_iter.assert_called_once_with(self.reg._ids)
def test___len__(self):
self.reg._ids = [1, 2, 3, 4]
self.assertEqual(4, len(self.reg))
def test_add(self):
self.assertSetEqual(set(), self.reg._ids)
self.assertIsNone(self.reg.add(123))
self.assertSetEqual({123}, self.reg._ids)
def test_discard(self):
self.reg._ids = {123}
self.assertIsNone(self.reg.discard(0))
self.assertIsNone(self.reg.discard(999))
self.assertIsNone(self.reg.discard(123))
self.assertSetEqual(set(), self.reg._ids)
async def test_acquire(self):
self.assertFalse(self.reg._lock.locked())
await self.reg.acquire()
self.assertTrue(self.reg._lock.locked())
def test_release(self):
self.reg._lock._locked = True
self.assertTrue(self.reg._lock.locked())
self.reg.release()
self.assertFalse(self.reg._lock.locked())
async def test_contextmanager(self):
self.assertFalse(self.reg._lock.locked())
async with self.reg as nothing:
self.assertIsNone(nothing)
self.assertTrue(self.reg._lock.locked())
self.assertFalse(self.reg._lock.locked())

View File

@ -87,10 +87,34 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()
def test_task_str(self):
self.assertEqual("task", helpers.tasks_str(1))
self.assertEqual("tasks", helpers.tasks_str(0))
self.assertEqual("tasks", helpers.tasks_str(-1))
self.assertEqual("tasks", helpers.tasks_str(2))
self.assertEqual("tasks", helpers.tasks_str(-10))
self.assertEqual("tasks", helpers.tasks_str(42))
def test_get_first_doc_line(self):
expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output}
something else
even more
""")
output = helpers.get_first_doc_line(mock_obj)
self.assertEqual(expected_output, output)
async def test_return_or_exception(self):
expected_output = '420'
mock_func = AsyncMock(return_value=expected_output)
args = (1, 3, 5)
kwargs = {'a': 1, 'b': 2, 'c': 'foo'}
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_awaited_once_with(*args, **kwargs)
mock_func = MagicMock(return_value=expected_output)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
class TestException(Exception):
pass
test_exception = TestException()
mock_func = MagicMock(side_effect=test_exception)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(test_exception, output)
mock_func.assert_called_once_with(*args, **kwargs)

View File

@ -18,19 +18,20 @@ __doc__ = """
Unittests for the `asyncio_taskpool.pool` module.
"""
import asyncio
from asyncio.exceptions import CancelledError
from asyncio.queues import Queue
from asyncio.locks import Semaphore
from asyncio.queues import QueueEmpty
from datetime import datetime
from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type
from asyncio_taskpool import pool, exceptions
from asyncio_taskpool.constants import DATETIME_FORMAT
EMPTY_LIST, EMPTY_DICT = [], {}
FOO, BAR = 'foo', 'bar'
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
FOO, BAR, BAZ = 'foo', 'bar', 'baz'
class TestException(Exception):
@ -45,19 +46,12 @@ class CommonTestCase(IsolatedAsyncioTestCase):
task_pool: pool.BaseTaskPool
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = pool.log.level
pool.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
def get_task_pool_init_params(self) -> dict:
return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
def setUp(self) -> None:
self.log_lvl = pool.log.level
pool.log.setLevel(999)
self._pools = self.TEST_CLASS._pools
# These three methods are called during initialization, so we mock them by default during setup:
self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
@ -76,6 +70,7 @@ class CommonTestCase(IsolatedAsyncioTestCase):
self._add_pool_patcher.stop()
self.pool_size_patcher.stop()
self.dunder_str_patcher.stop()
pool.log.setLevel(self.log_lvl)
class BaseTaskPoolTestCase(CommonTestCase):
@ -88,19 +83,23 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
self.assertEqual(0, self.task_pool._num_started)
self.assertEqual(0, self.task_pool._num_cancellations)
self.assertFalse(self.task_pool._locked)
self.assertEqual(0, self.task_pool._counter)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertFalse(self.task_pool._closed)
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
self.assertEqual(self.mock_idx, self.task_pool._idx)
self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with()
@ -143,26 +142,56 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertFalse(self.task_pool._locked)
def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz']
self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_running)
def test_num_cancelled(self):
self.task_pool._num_cancelled = 3
self.assertEqual(3, self.task_pool.num_cancelled)
def test_num_cancellations(self):
self.task_pool._num_cancellations = 3
self.assertEqual(3, self.task_pool.num_cancellations)
def test_num_ended(self):
self.task_pool._num_ended = 3
self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_ended)
def test_num_finished(self):
self.task_pool._num_cancelled = cancelled = 69
self.task_pool._num_ended = ended = 420
self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)
self.task_pool._num_cancellations = num_cancellations = 69
num_ended = 420
self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)}
self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished)
def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
def test_get_group_ids(self):
group_name, ids = 'abcdef', [1, 2, 3]
self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name))
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self):
self.task_pool._closed = True
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try:
with self.assertRaises(AssertionError):
self.task_pool._check_start(awaitable=None, function=None)
with self.assertRaises(AssertionError):
self.task_pool._check_start(awaitable=mock_coroutine, function=mock_coroutine_function)
with self.assertRaises(exceptions.NotCoroutine):
self.task_pool._check_start(awaitable=mock_coroutine_function, function=None)
with self.assertRaises(exceptions.NotCoroutine):
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False
self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
self.assertIsNone(self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=True))
finally:
await mock_coroutine
def test__task_name(self):
i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@ -171,12 +200,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task
self.task_pool._num_cancellations = cancelled = 3
self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
self.assertNotIn(task_id, self.task_pool._tasks_running)
self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancellations)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -184,15 +213,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
# End running task:
self.task_pool._running[task_id] = mock_task
self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertNotIn(task_id, self.task_pool._tasks_running)
self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -200,11 +227,10 @@ class BaseTaskPoolTestCase(CommonTestCase):
mock_execute_optional.reset_mock()
# End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_ended.pop(task_id)
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._cancelled)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertNotIn(task_id, self.task_pool._tasks_cancelled)
self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -246,92 +272,52 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock,
mock_create_task: MagicMock):
def reset_mocks() -> None:
mock__task_name.reset_mock()
mock__task_wrapper.reset_mock()
mock_create_task.reset_mock()
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start')
async def test__start_task(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__task_name: MagicMock,
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123
mock_coroutine, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
self.task_pool._num_started = count = 123
self.task_pool._enough_room._value = room = 123
def check_nothing_changed() -> None:
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
check_nothing_changed()
self.task_pool._locked = True
ignore_closed = False
mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
check_nothing_changed()
ignore_closed = True
mock_awaitable = mock_coroutine()
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
group_name, ignore_lock = 'testgroup', True
output = await self.task_pool._start_task(mock_coroutine, group_name=group_name, ignore_lock=ignore_lock,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
mock__check_start.assert_called_once_with(awaitable=mock_coroutine, ignore_lock=ignore_lock)
self.assertEqual(room - 1, self.task_pool._enough_room._value)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name])
mock_reg_cls.assert_called_once_with()
mock_group_reg.__aenter__.assert_awaited_once_with()
mock_group_reg.add.assert_called_once_with(count)
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
reset_mocks()
self.task_pool._counter = count
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
mock_awaitable = mock_coroutine()
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
await mock_awaitable
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
mock__task_wrapper.assert_called_once_with(mock_coroutine, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(coro=mock_wrapped, name=FOO)
self.assertEqual(mock_task, self.task_pool._tasks_running[count])
mock_group_reg.__aexit__.assert_awaited_once()
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
def test__get_running_task(self, mock__task_name: MagicMock):
task_id, mock_task = 555, MagicMock()
self.task_pool._running[task_id] = mock_task
self.task_pool._tasks_running[task_id] = mock_task
output = self.task_pool._get_running_task(task_id)
self.assertEqual(mock_task, output)
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_running.pop(task_id)
with self.assertRaises(exceptions.AlreadyCancelled):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
self.task_pool._tasks_ended[task_id] = self.task_pool._tasks_cancelled.pop(task_id)
with self.assertRaises(exceptions.TaskEnded):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
del self.task_pool._ended[task_id]
del self.task_pool._tasks_ended[task_id]
with self.assertRaises(exceptions.InvalidTaskID):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@ -344,263 +330,416 @@ class BaseTaskPoolTestCase(CommonTestCase):
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
def test_cancel_all(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
self.task_pool._running = {1: mock_task1, 2: mock_task2}
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(self.task_pool.cancel_all(FOO))
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_task1.cancel.assert_called_once_with(msg=FOO)
mock_task2.cancel.assert_called_once_with(msg=FOO)
def test__cancel_and_remove_all_from_group(self):
task_id = 555
mock_cancel = MagicMock()
self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
class MockRegister(set, MagicMock):
pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
mock_cancel.assert_called_once_with(msg=FOO)
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
self.task_pool._task_groups[FOO] = mock_group_reg
with self.assertRaises(exceptions.InvalidGroupName):
await self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called()
mock_grp_aenter.assert_not_called()
mock_grp_aexit.assert_not_called()
self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR))
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
mock_grp_aenter.assert_awaited_once_with()
mock_grp_aexit.assert_awaited_once()
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
self.task_pool._task_groups[BAR] = mock_group_reg
self.assertIsNone(await self.task_pool.cancel_all(FOO))
mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO)
mock_grp_aenter.assert_awaited_once_with()
mock_grp_aexit.assert_awaited_once()
async def test_flush(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._interrupt_flag.set()
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._tasks_ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
async def test_gather_and_close(self):
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._tasks_running = running = {789: mock_running_func()}
async def test_gather(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR)
mock_queue_join = AsyncMock()
self.task_pool._before_gathering = before_gather = [mock_queue_join()]
self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set()
assert not self.task_pool._locked
with self.assertRaises(exceptions.PoolStillUnlocked):
await self.task_pool.gather()
self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled)
self.assertDictEqual(self.task_pool._running, running)
self.assertListEqual(self.task_pool._before_gathering, before_gather)
self.assertTrue(self.task_pool._interrupt_flag.is_set())
await self.task_pool.gather_and_close()
self.assertDictEqual(ended, self.task_pool._tasks_ended)
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
self.assertDictEqual(running, self.task_pool._tasks_running)
self.assertListEqual(before_gather, self.task_pool._before_gathering)
self.assertFalse(self.task_pool._closed)
self.task_pool._locked = True
def check_assertions(output) -> None:
self.assertListEqual([FOO, test_exception, BAR], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
check_assertions(await self.task_pool.gather(return_exceptions=True))
self.task_pool._before_gathering = [mock_queue_join()]
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
check_assertions(await self.task_pool.gather(return_exceptions=True))
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_before_gather.assert_awaited_once_with()
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
mock_running_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
self.assertTrue(self.task_pool._closed)
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 12345
mock_awaitable = MagicMock()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
self.base_class_init = self.base_class_init_patcher.start()
super().setUp()
def tearDown(self) -> None:
self.base_class_init_patcher.stop()
super().tearDown()
def test_init(self):
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
self.task_pool._group_meta_tasks_running[BAR] = {mock_task1, mock_task2}
self.assertIsNone(self.task_pool._cancel_group_meta_tasks(FOO))
self.assertDictEqual({BAR: {mock_task1, mock_task2}}, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
mock_task1.cancel.assert_not_called()
mock_task2.cancel.assert_not_called()
self.assertIsNone(self.task_pool._cancel_group_meta_tasks(BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual({mock_task1, mock_task2}, self.task_pool._meta_tasks_cancelled)
mock_task1.cancel.assert_called_once_with()
mock_task2.cancel.assert_called_once_with()
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
@patch.object(pool.TaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock,
mock_base__cancel_and_remove_all_from_group: MagicMock):
group_name, group_reg, msg = 'xyz', MagicMock(), FOO
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg))
mock__cancel_group_meta_tasks.assert_called_once_with(group_name)
mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg)
@patch.object(pool.BaseTaskPool, 'cancel_group')
async def test_cancel_group(self, mock_base_cancel_group: AsyncMock):
group_name, msg = 'abc', 'xyz'
await self.task_pool.cancel_group(group_name, msg=msg)
mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg)
@patch.object(pool.BaseTaskPool, 'cancel_all')
async def test_cancel_all(self, mock_base_cancel_all: AsyncMock):
msg = 'xyz'
await self.task_pool.cancel_all(msg=msg)
mock_base_cancel_all.assert_awaited_once_with(msg=msg)
def test__pop_ended_meta_tasks(self):
mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
self.task_pool._group_meta_tasks_running[FOO] = {mock_task, mock_done_task1}
mock_done_task2, mock_done_task3 = MagicMock(done=lambda: True), MagicMock(done=lambda: True)
self.task_pool._group_meta_tasks_running[BAR] = {mock_done_task2, mock_done_task3}
expected_output = {mock_done_task1, mock_done_task2, mock_done_task3}
output = self.task_pool._pop_ended_meta_tasks()
self.assertSetEqual(expected_output, output)
self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
@patch.object(pool.TaskPool, '_pop_ended_meta_tasks')
@patch.object(pool.BaseTaskPool, 'flush')
async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock):
mock_ended_meta_task = AsyncMock()
mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
self.assertIsNone(await self.task_pool.flush(return_exceptions=False))
mock_base_flush.assert_awaited_once_with(return_exceptions=False)
mock__pop_ended_meta_tasks.assert_called_once_with()
mock_ended_meta_task.assert_awaited_once_with()
mock_cancelled_meta_task.assert_awaited_once_with()
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
@patch.object(pool.BaseTaskPool, 'gather_and_close')
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock):
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
mock_meta_task1.assert_awaited_once_with()
mock_meta_task2.assert_awaited_once_with()
mock_cancelled_meta_task.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
@patch.object(pool, 'datetime')
def test__generate_group_name(self, mock_datetime: MagicMock):
prefix, func = 'x y z', AsyncMock(__name__=BAR)
dt = datetime(1776, 7, 4, 0, 0, 1)
mock_datetime.now = MagicMock(return_value=dt)
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}'
output = pool.TaskPool._generate_group_name(prefix, func)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock):
group_name = FOO + BAR
mock_awaitable = object()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
mock__start_task.assert_has_awaits(3 * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
])
mock_func.reset_mock()
mock__start_task.reset_mock()
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(num * [call(*args)])
mock__start_task.assert_has_awaits(num * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
])
@patch.object(pool.TaskPool, '_apply_one')
async def test_apply(self, mock__apply_one: AsyncMock):
mock__apply_one.return_value = mock_id = 67890
mock_func, num = MagicMock(), 3
@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start')
async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object()
mock_task_future = AsyncMock()
mock_create_task.return_value = mock_task_future()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
expected_output = num * [mock_id]
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
self.task_pool._task_groups = {}
async def test__queue_producer(self):
def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock_group_reg.__aenter__.assert_awaited_once_with()
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine)
mock_group_reg.__aexit__.assert_awaited_once()
mock_task_future.assert_awaited_once_with()
output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
check_assertions(group_name, output)
mock__generate_group_name.assert_not_called()
mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
mock_group_reg.__aenter__.reset_mock()
mock__apply_num.reset_mock()
mock_create_task.reset_mock()
mock_group_reg.__aexit__.reset_mock()
mock_task_future = AsyncMock()
mock_create_task.return_value = mock_task_future()
output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
check_assertions(generated_name, output)
mock__generate_group_name.assert_called_once_with('apply', mock_func)
@patch.object(pool, 'Queue')
async def test__queue_producer(self, mock_queue_cls: MagicMock):
mock_put = AsyncMock()
mock_q = MagicMock(put=mock_put)
args = (FOO, BAR, 123)
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_has_awaits([call(arg) for arg in args])
mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
item1, item2, item3 = FOO, 420, 69
arg_iter = iter([item1, item2, item3])
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
with self.assertRaises(StopIteration):
next(arg_iter)
mock_put.reset_mock()
self.task_pool._interrupt_flag.set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_not_awaited()
@patch.object(pool, 'partial')
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
mock_partial: MagicMock):
mock_partial.return_value = queue_callback = 'not really'
mock_star_function.return_value = awaitable = 'totally an awaitable'
q, arg = Queue(), 420.69
q.put_nowait(arg)
mock_func, stars = MagicMock(), 3
mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
mock_partial.assert_not_called()
mock_put.side_effect = [CancelledError, None]
arg_iter = iter([item1, item2, item3])
mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
mock_queue.item_processed.assert_has_calls([call(), call()])
self.assertListEqual([item2, item3], list(arg_iter))
@patch.object(pool, 'execute_optional')
@patch.object(pool.TaskPool, '_queue_consumer')
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
task_id, mock_q = 420, MagicMock()
mock_func, stars = MagicMock(), 3
mock_wait = AsyncMock()
mock_flag = MagicMock(wait=mock_wait)
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb))
mock_wait.assert_awaited_once_with()
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
semaphore, mock_end_cb = Semaphore(1), MagicMock()
wrapped = pool.TaskPool._get_map_end_callback(semaphore, mock_end_cb)
task_id = 1234
await wrapped(task_id)
self.assertEqual(2, semaphore._value)
mock_execute_optional.assert_awaited_once_with(mock_end_cb, args=(task_id,))
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
@patch.object(pool, 'Semaphore')
@patch.object(pool.TaskPool, '_get_map_end_callback')
async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock,
mock__start_task: AsyncMock, mock_star_function: MagicMock):
mock__get_map_end_callback.return_value = map_cb = MagicMock()
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
mock_star_function.return_value = awaitable = 'totally an awaitable'
arg1, arg2 = 123456789, 'function argument'
mock_q_maxsize = 3
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]),
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
group_name, mock_func, stars = 'whatever', MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(semaphore.locked())
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
@patch.object(pool, 'iter')
@patch.object(pool, 'create_task')
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
mock_create_task: MagicMock, mock_iter: MagicMock):
args, num_tasks = (FOO, BAR, 1, 2, 3), 2
mock_join_queue.return_value = mock_join = 'awaitable'
mock_iter.return_value = args_iter = iter(args)
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(num_tasks, output_q.qsize())
for arg in args[:num_tasks]:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
for arg in args[num_tasks:]:
self.assertEqual(arg, next(args_iter))
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_called_once_with(output_q, args_iter)
mock_create_task.assert_called_once_with(mock_producer_coro)
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool, 'Queue')
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start')
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_queue_cls.return_value = mock_q = MagicMock()
mock_join_queue.return_value = fake_join = object()
mock__queue_producer.return_value = fake_producer = object()
mock__queue_consumer.return_value = fake_consumer = object()
fake_task1, fake_task2 = object(), object()
mock_create_task.side_effect = [fake_task1, fake_task2]
self.task_pool._before_gathering.clear()
mock_join_queue.reset_mock()
mock__queue_producer.reset_mock()
mock_create_task.reset_mock()
num_tasks = 6
mock_iter.return_value = args_iter = iter(args)
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(len(args), output_q.qsize())
for arg in args:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_not_called()
mock_create_task.assert_not_called()
@patch.object(pool, 'Event')
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock,
mock_event_cls: MagicMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_flag_set = MagicMock()
mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
mock_func, stars = MagicMock(), 3
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
group_name, group_size = 'onetwothree', 0
func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._locked = False
with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_flag_set.assert_not_called()
with self.assertRaises(ValueError):
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=func)
self.task_pool._locked = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)])
mock_flag_set.assert_called_once_with()
mock__check_start.reset_mock()
group_size = 1234
self.task_pool._task_groups = {group_name: MagicMock()}
with self.assertRaises(exceptions.InvalidGroupName):
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=func)
mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
self.task_pool._before_gathering = []
self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
mock__check_start.assert_called_once_with(function=func)
mock_reg_cls.assert_called_once_with()
self.task_pool._task_groups[group_name] = mock_group_reg
mock_group_reg.__aenter__.assert_awaited_once_with()
mock_queue_cls.assert_called_once_with(maxsize=group_size)
mock_join_queue.assert_called_once_with(mock_q)
self.assertListEqual([fake_join], self.task_pool._before_gathering)
mock__queue_producer.assert_called_once()
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
mock_group_reg.__aexit__.assert_awaited_once()
@patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock):
@patch.object(pool.TaskPool, '_generate_group_name')
async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock()
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called()
mock__map.reset_mock()
output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('map', mock_func)
@patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock):
@patch.object(pool.TaskPool, '_generate_group_name')
async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock()
args_iter, num_tasks = ([FOO], [BAR]), 2
args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called()
mock__map.reset_mock()
output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('starmap', mock_func)
@patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock):
@patch.object(pool.TaskPool, '_generate_group_name')
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock()
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called()
mock__map.reset_mock()
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
class SimpleTaskPoolTestCase(CommonTestCase):
@ -667,7 +806,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def test_stop(self, mock_cancel: MagicMock):
num = 2
id1, id2, id3 = 5, 6, 7
self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR}
self.task_pool._tasks_running = {id1: FOO, id2: BAR, id3: FOO + BAR}
output = self.task_pool.stop(num)
expected_output = [id3, id2]
self.assertEqual(expected_output, output)
@ -689,3 +828,10 @@ class SimpleTaskPoolTestCase(CommonTestCase):
self.assertEqual(expected_output, output)
mock_num_running.assert_called_once_with()
mock_stop.assert_called_once_with(num)
def set_up_mock_group_register(mock_reg_cls: MagicMock) -> MagicMock:
mock_grp_aenter, mock_grp_aexit, mock_grp_add = AsyncMock(), AsyncMock(), MagicMock()
mock_reg_cls.return_value = mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit,
add=mock_grp_add)
return mock_group_reg

View File

@ -0,0 +1,43 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.queue_context` module.
"""
from unittest import IsolatedAsyncioTestCase
from unittest.mock import MagicMock, patch
from asyncio_taskpool.queue_context import Queue
class QueueTestCase(IsolatedAsyncioTestCase):
def test_item_processed(self):
queue = Queue()
queue._unfinished_tasks = 1000
queue.item_processed()
self.assertEqual(999, queue._unfinished_tasks)
@patch.object(Queue, 'item_processed')
async def test_contextmanager(self, mock_item_processed: MagicMock):
queue = Queue()
item = 'foo'
queue.put_nowait(item)
async with queue as item_from_queue:
self.assertEqual(item, item_from_queue)
mock_item_processed.assert_not_called()
mock_item_processed.assert_called_once_with()

View File

@ -1,12 +1,18 @@
# Using `asyncio-taskpool`
## Contents
- [Contents](#contents)
- [Minimal example for `SimpleTaskPool`](#minimal-example-for-simpletaskpool)
- [Advanced example for `TaskPool`](#advanced-example-for-taskpool)
- [Control server example](#control-server-example)
## Minimal example for `SimpleTaskPool`
With a `SimpleTaskPool` the function to execute as well as the arguments with which to execute it must be defined during its initialization (and they cannot be changed later). The only control you have after initialization is how many of such tasks are being run.
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
The following demo script enables full log output first for additional clarity. It is complete and should work as is.
```python
import logging
@ -28,54 +34,57 @@ async def work(n: int) -> None:
"""
for i in range(n):
await asyncio.sleep(1)
print("did", i)
print("> did", i)
async def main() -> None:
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
await pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit
await pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
pool.lock() # required for the last line
await pool.gather() # awaits all tasks, then flushes the pool
await pool.gather_and_close() # awaits all tasks, then flushes the pool
if __name__ == '__main__':
asyncio.run(main())
```
### Output
<details>
<summary>Output: (Click to expand)</summary>
```
SimpleTaskPool-0 initialized
Started SimpleTaskPool-0_Task-0
Started SimpleTaskPool-0_Task-1
Started SimpleTaskPool-0_Task-2
did 0
did 0
did 0
> did 0
> did 0
> did 0
Started SimpleTaskPool-0_Task-3
did 1
did 1
did 1
did 0
> did 1
> did 1
> did 1
> did 0
> did 2
> did 2
SimpleTaskPool-0 is locked!
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
Cancelling SimpleTaskPool-0_Task-2 ...
Cancelled SimpleTaskPool-0_Task-2
Ended SimpleTaskPool-0_Task-2
did 2
did 2
did 3
did 3
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
> did 3
> did 3
Ended SimpleTaskPool-0_Task-0
Ended SimpleTaskPool-0_Task-1
did 4
did 4
> did 4
> did 4
```
</details>
## Advanced example for `TaskPool`
@ -83,9 +92,7 @@ This time, we want to start tasks from _different_ coroutine functions **and** w
As with the simple example, we need "worker" coroutine functions that can do something asynchronously, as well as a main coroutine function that sets up the pool, starts the tasks, and eventually awaits them.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
The following demo script enables full log output first for additional clarity. It is complete and should work as is.
```python
import logging
@ -101,133 +108,162 @@ async def work(start: int, stop: int, step: int = 1) -> None:
"""Pseudo-worker function counting through a range with a second of sleep in between each iteration."""
for i in range(start, stop, step):
await asyncio.sleep(1)
print("work with", i)
print("> work with", i)
async def other_work(a: int, b: int) -> None:
"""Different pseudo-worker counting through a range with half a second of sleep in between each iteration."""
for i in range(a, b):
await asyncio.sleep(0.5)
print("other_work with", i)
print("> other_work with", i)
async def main() -> None:
# Initialize a new task pool instance and limit its size to 3 tasks.
pool = TaskPool(3)
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments).
print("Called `apply`")
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments).
print("> Called `apply`")
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
# Let the tasks work for a bit.
await asyncio.sleep(1.5)
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
# positional arguments by using `starmap`, but have **no more than two of those** run concurrently.
# positional arguments by using `starmap`, but we want no more than two of those to run concurrently.
# Since we set our pool size to 3, and already have two tasks working within the pool,
# only the first one of these will start immediately (and receive ID 2).
# The second one will start (with ID 3), only once there is room in the pool,
# which -- in this example -- will be the case after ID 2 ends;
# until then the `starmap` method call **will block**!
# which -- in this example -- will be the case after ID 2 ends.
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
# **only** once there is room in the pool **and** no more than one of these last four tasks is running.
# only once there is room in the pool and no more than one other task of these new ones is running.
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
print("Calling `starmap`...")
await pool.starmap(other_work, args_list, num_tasks=2)
print("`starmap` returned")
await pool.starmap(other_work, args_list, group_size=2)
print("> Called `starmap`")
# Now we lock the pool, so that we can safely await all our tasks.
pool.lock()
# Finally, we block, until all tasks have ended.
print("Called `gather`")
await pool.gather()
print("Done.")
print("> Calling `gather_and_close`...")
await pool.gather_and_close()
print("> Done.")
if __name__ == '__main__':
asyncio.run(main())
```
### Output
Additional comments for the output are provided with `<---` next to the output lines.
<details>
<summary>Output: (Click to expand)</summary>
(Keep in mind that the logger and `print` asynchronously write to `stdout`.)
```
TaskPool-0 initialized
Started TaskPool-0_Task-0
Started TaskPool-0_Task-1
Called `apply`
work with 100
work with 100
Calling `starmap`... <--- notice that this blocks as expected
Started TaskPool-0_Task-2
work with 110
work with 110
other_work with 0
other_work with 1
work with 120
work with 120
other_work with 2
other_work with 3
work with 130
work with 130
other_work with 4
other_work with 5
work with 140
work with 140
other_work with 6
other_work with 7
work with 150
work with 150
other_work with 8
Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
> Called `apply`
> work with 100
> work with 100
> Called `starmap` <--- notice that this immediately returns, even before Task-2 is started
> Calling `gather_and_close`... <--- this blocks `main()` until all tasks have ended
TaskPool-0 is locked!
Started TaskPool-0_Task-2 <--- at this point the pool is full
> work with 110
> work with 110
> other_work with 0
> other_work with 1
> work with 120
> work with 120
> other_work with 2
> other_work with 3
> work with 130
> work with 130
> other_work with 4
> other_work with 5
> work with 140
> work with 140
> other_work with 6
> other_work with 7
> work with 150
> work with 150
> other_work with 8
Ended TaskPool-0_Task-2 <--- this frees up room for one more task from `starmap`
Started TaskPool-0_Task-3
other_work with 9
`starmap` returned
Called `gather` <--- now this will block `main()` until all tasks have ended
work with 160
work with 160
other_work with 10
other_work with 11
work with 170
work with 170
other_work with 12
other_work with 13
work with 180
work with 180
other_work with 14
other_work with 15
> other_work with 9
> work with 160
> work with 160
> other_work with 10
> other_work with 11
> work with 170
> work with 170
> other_work with 12
> other_work with 13
> work with 180
> work with 180
> other_work with 14
> other_work with 15
Ended TaskPool-0_Task-0
Ended TaskPool-0_Task-1 <--- even though there is room in the pool now, Task-5 will not start
Started TaskPool-0_Task-4
work with 190
work with 190
other_work with 16
other_work with 20
other_work with 17
other_work with 21
other_work with 18
other_work with 22
other_work with 19
Ended TaskPool-0_Task-3 <--- now that only Task-4 is left, Task-5 will start
Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start
> work with 190
> work with 190
> other_work with 16
> other_work with 17
> other_work with 20
> other_work with 18
> other_work with 21
Ended TaskPool-0_Task-3 <--- now that only Task-4 of the group remains, Task-5 starts
Started TaskPool-0_Task-5
other_work with 23
other_work with 30
other_work with 24
other_work with 31
other_work with 25
other_work with 32
other_work with 26
other_work with 33
other_work with 27
other_work with 34
other_work with 28
other_work with 35
> other_work with 19
> other_work with 22
> other_work with 23
> other_work with 30
> other_work with 24
> other_work with 31
> other_work with 25
> other_work with 32
> other_work with 26
> other_work with 33
> other_work with 27
> other_work with 34
> other_work with 28
> other_work with 35
> other_work with 29
> other_work with 36
Ended TaskPool-0_Task-4
other_work with 29
other_work with 36
other_work with 37
other_work with 38
other_work with 39
Done.
> other_work with 37
> other_work with 38
> other_work with 39
Ended TaskPool-0_Task-5
> Done.
```
(Added comments with `<---` next to the output lines.)
Keep in mind that the logger and `print` asynchronously write to `stdout`, so the order of lines in your output may be slightly different.
</details>
## Control server example
One of the main features of `asyncio-taskpool` is the ability to control a task pool "from the outside" at runtime.
The [example_server.py](./example_server.py) script launches a couple of worker tasks within a `SimpleTaskPool` instance and then starts a `TCPControlServer` instance for that task pool. The server is configured to locally bind to port `9999` and is stopped automatically after the "work" is done.
To run the script:
```shell
python usage/example_server.py
```
You can then connect to the server via the command line interface:
```shell
python -m asyncio_taskpool.control tcp localhost 9999
```
The CLI starts a `TCPControlClient` that connects to our example server. Once the connection is established, it gives you an input prompt allowing you to issue commands to the task pool:
```
Connected to SimpleTaskPool-0
Type '-h' to get help and usage instructions for all available commands.
>
```
It may be useful to run the server script and the client interface in two separate terminal windows side by side. The server script is configured with a verbose logger and will react to any commands issued by the client with detailed log messages in the terminal.
---
© 2022 Daniil Fajnberg

View File

@ -23,7 +23,7 @@ Use the main CLI client to interface at the socket.
import asyncio
import logging
from asyncio_taskpool import SimpleTaskPool, UnixControlServer
from asyncio_taskpool import SimpleTaskPool, TCPControlServer
from asyncio_taskpool.constants import PACKAGE_NAME
@ -34,11 +34,11 @@ logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
async def work(item: int) -> None:
"""The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
await asyncio.sleep(1)
print("worked on", item)
print("worked on", item, flush=True)
async def worker(q: asyncio.Queue) -> None:
"""Simulates doing asynchronous work that takes a little bit of time to finish."""
"""Simulates doing asynchronous work that takes a bit of time to finish."""
# We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
while True:
# We want to block here, until we can get the next item from the queue.
@ -65,21 +65,21 @@ async def main() -> None:
# We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit.
for item in range(100):
q.put_nowait(item)
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
await pool.start(3) # launches three worker tasks
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join()
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
# Since we don't need any "work" done anymore, we can get rid of our control server by cancelling the task.
control_server_task.cancel()
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks.
pool.stop_all()
pool.lock()
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
pool.stop_all()
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
# we just silently collect their exceptions along with their return values.
await pool.gather(return_exceptions=True)
await pool.gather_and_close(return_exceptions=True)
await control_server_task