generated from daniil-berg/boilerplate-py
Compare commits
17 Commits
Author | SHA1 | Date | |
---|---|---|---|
ce0f9a1f65 | |||
5dad4ab0c7 | |||
ae6bb1bd17 | |||
e501a849f3 | |||
ed6badb088 | |||
c63f079da4 | |||
4994135062 | |||
d0c0177681 | |||
bac7b32342 | |||
96d01e7259 | |||
3f3eb7ce38 | |||
05d51eface | |||
b6aed727e9 | |||
c9a8d9ecd1 | |||
538b9cc91c | |||
3fb451a00e | |||
be03097bf4 |
@ -5,8 +5,10 @@ omit =
|
|||||||
.venv/*
|
.venv/*
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
fail_under = 100
|
|
||||||
show_missing = True
|
show_missing = True
|
||||||
skip_covered = False
|
skip_covered = False
|
||||||
|
exclude_lines =
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
if __name__ == ['"]__main__['"]:
|
||||||
omit =
|
omit =
|
||||||
tests/*
|
tests/*
|
||||||
|
14
README.md
14
README.md
@ -14,12 +14,20 @@ If you need control over a task pool at runtime, you can launch an asynchronous
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like:
|
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from asyncio_taskpool import SimpleTaskPool
|
from asyncio_taskpool import SimpleTaskPool
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
async def work(foo, bar): ...
|
async def work(foo, bar): ...
|
||||||
|
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
pool = SimpleTaskPool(work, args=('xyz', 420))
|
pool = SimpleTaskPool(work, args=('xyz', 420))
|
||||||
await pool.start(5)
|
await pool.start(5)
|
||||||
@ -27,11 +35,11 @@ async def main():
|
|||||||
pool.stop(3)
|
pool.stop(3)
|
||||||
...
|
...
|
||||||
pool.lock()
|
pool.lock()
|
||||||
await pool.gather()
|
await pool.gather_and_close()
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
|
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
|
||||||
|
|
||||||
For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md).
|
For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md).
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 0.2.1
|
version = 0.5.0
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -20,4 +20,4 @@ Brings the main classes up to package level for import convenience.
|
|||||||
|
|
||||||
|
|
||||||
from .pool import TaskPool, SimpleTaskPool
|
from .pool import TaskPool, SimpleTaskPool
|
||||||
from .server import UnixControlServer
|
from .server import TCPControlServer, UnixControlServer
|
||||||
|
@ -25,15 +25,16 @@ from asyncio import run
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from .client import ControlClient, UnixControlClient
|
from .client import ControlClient, TCPControlClient, UnixControlClient
|
||||||
from .constants import PACKAGE_NAME
|
from .constants import PACKAGE_NAME
|
||||||
from .pool import TaskPool
|
from .pool import TaskPool
|
||||||
from .server import ControlServer
|
from .server import TCPControlServer, UnixControlServer
|
||||||
|
|
||||||
|
|
||||||
CONN_TYPE = 'conn_type'
|
CONN_TYPE = 'conn_type'
|
||||||
UNIX, TCP = 'unix', 'tcp'
|
UNIX, TCP = 'unix', 'tcp'
|
||||||
SOCKET_PATH = 'path'
|
SOCKET_PATH = 'path'
|
||||||
|
HOST, PORT = 'host', 'port'
|
||||||
|
|
||||||
|
|
||||||
def parse_cli() -> Dict[str, Any]:
|
def parse_cli() -> Dict[str, Any]:
|
||||||
@ -46,7 +47,18 @@ def parse_cli() -> Dict[str, Any]:
|
|||||||
unix_parser.add_argument(
|
unix_parser.add_argument(
|
||||||
SOCKET_PATH,
|
SOCKET_PATH,
|
||||||
type=Path,
|
type=Path,
|
||||||
help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
|
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
|
||||||
|
f"listening."
|
||||||
|
)
|
||||||
|
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
|
||||||
|
tcp_parser.add_argument(
|
||||||
|
HOST,
|
||||||
|
help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
|
||||||
|
)
|
||||||
|
tcp_parser.add_argument(
|
||||||
|
PORT,
|
||||||
|
type=int,
|
||||||
|
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
|
||||||
)
|
)
|
||||||
return vars(parser.parse_args())
|
return vars(parser.parse_args())
|
||||||
|
|
||||||
@ -54,10 +66,9 @@ def parse_cli() -> Dict[str, Any]:
|
|||||||
async def main():
|
async def main():
|
||||||
kwargs = parse_cli()
|
kwargs = parse_cli()
|
||||||
if kwargs[CONN_TYPE] == UNIX:
|
if kwargs[CONN_TYPE] == UNIX:
|
||||||
client = UnixControlClient(path=kwargs[SOCKET_PATH])
|
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
|
||||||
elif kwargs[CONN_TYPE] == TCP:
|
elif kwargs[CONN_TYPE] == TCP:
|
||||||
# TODO: Implement the TCP client class
|
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
|
||||||
client = UnixControlClient(path=kwargs[SOCKET_PATH])
|
|
||||||
else:
|
else:
|
||||||
print("Invalid connection type", file=sys.stderr)
|
print("Invalid connection type", file=sys.stderr)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
@ -19,66 +19,173 @@ Classes of control clients for a simply interface to a task pool control server.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
from asyncio.streams import StreamReader, StreamWriter, open_connection
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from asyncio_taskpool import constants
|
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
|
||||||
from asyncio_taskpool.types import ClientConnT
|
from .types import ClientConnT, PathT
|
||||||
|
|
||||||
|
|
||||||
class ControlClient(ABC):
|
class ControlClient(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for a simple implementation of a task pool control client.
|
||||||
|
|
||||||
|
Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
|
||||||
|
TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
|
||||||
|
This is a minimal working implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def client_info() -> dict:
|
||||||
|
"""Returns a dictionary of client information relevant for the handshake with the server."""
|
||||||
|
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
"""
|
||||||
|
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
|
||||||
|
|
||||||
|
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs`
|
||||||
|
(unpacked) as keyword-arguments.
|
||||||
|
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
|
||||||
|
`None` and `None`, if it failed to establish the defined connection.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(self, **conn_kwargs) -> None:
|
def __init__(self, **conn_kwargs) -> None:
|
||||||
|
"""Simply stores the connection keyword-arguments necessary for opening the connection."""
|
||||||
self._conn_kwargs = conn_kwargs
|
self._conn_kwargs = conn_kwargs
|
||||||
self._connected: bool = False
|
self._connected: bool = False
|
||||||
|
|
||||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
"""
|
||||||
|
Performs the first interaction with the server providing it with the necessary client information.
|
||||||
|
|
||||||
|
Upon completion, the server's info is printed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
|
||||||
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
"""
|
||||||
|
self._connected = True
|
||||||
|
writer.write(json.dumps(self.client_info()).encode())
|
||||||
|
await writer.drain()
|
||||||
|
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
|
|
||||||
|
def _get_command(self, writer: StreamWriter) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
|
||||||
|
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
msg = input("> ").strip().lower()
|
msg = input("> ").strip().lower()
|
||||||
except EOFError:
|
except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
|
||||||
msg = constants.CLIENT_EXIT
|
msg = CLIENT_EXIT
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
||||||
print()
|
print()
|
||||||
return
|
return
|
||||||
if msg == constants.CLIENT_EXIT:
|
if msg == CLIENT_EXIT:
|
||||||
writer.close()
|
writer.close()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
return
|
return
|
||||||
|
return msg
|
||||||
|
|
||||||
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
"""
|
||||||
|
Reacts to the user's command, potentially performing a back-and-forth interaction with the server.
|
||||||
|
|
||||||
|
If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
|
||||||
|
If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
|
||||||
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
"""
|
||||||
|
cmd = self._get_command(writer)
|
||||||
|
if cmd is None:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
writer.write(msg.encode())
|
# Send the command to the server.
|
||||||
|
writer.write(cmd.encode())
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
print(e, file=sys.stderr)
|
print(e, file=sys.stderr)
|
||||||
return
|
return
|
||||||
print((await reader.read(constants.MSG_BYTES)).decode())
|
# Await the server's response, then print it.
|
||||||
|
print((await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
|
|
||||||
async def start(self):
|
async def start(self) -> None:
|
||||||
reader, writer = await self.open_connection(**self._conn_kwargs)
|
"""
|
||||||
|
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.
|
||||||
|
|
||||||
|
If the connection can not be established, an error message is printed to `stderr` and the method returns.
|
||||||
|
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
|
||||||
|
disconnected-message.
|
||||||
|
"""
|
||||||
|
reader, writer = await self._open_connection(**self._conn_kwargs)
|
||||||
if reader is None:
|
if reader is None:
|
||||||
print("Failed to connect.", file=sys.stderr)
|
print("Failed to connect.", file=sys.stderr)
|
||||||
return
|
return
|
||||||
self._connected = True
|
await self._server_handshake(reader, writer)
|
||||||
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
|
|
||||||
while self._connected:
|
while self._connected:
|
||||||
await self._interact(reader, writer)
|
await self._interact(reader, writer)
|
||||||
print("Disconnected from control server.")
|
print("Disconnected from control server.")
|
||||||
|
|
||||||
|
|
||||||
class UnixControlClient(ControlClient):
|
class TCPControlClient(ControlClient):
|
||||||
def __init__(self, **conn_kwargs) -> None:
|
"""Task pool control client that expects a TCP socket to be exposed by the control server."""
|
||||||
self._socket_path = Path(conn_kwargs.pop('path'))
|
|
||||||
|
def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None:
|
||||||
|
"""In addition to what the base class does, `host` and `port` are expected as non-optional arguments."""
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
super().__init__(**conn_kwargs)
|
super().__init__(**conn_kwargs)
|
||||||
|
|
||||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
"""
|
||||||
|
Wrapper around the `asyncio.open_connection` function.
|
||||||
|
|
||||||
|
Returns a tuple of `None` and `None`, if the connection can not be established;
|
||||||
|
otherwise, the stream-reader and -writer tuple is returned.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return await open_unix_connection(self._socket_path, **kwargs)
|
return await open_connection(self._host, self._port, **kwargs)
|
||||||
|
except ConnectionError as e:
|
||||||
|
print(str(e), file=sys.stderr)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
class UnixControlClient(ControlClient):
|
||||||
|
"""Task pool control client that expects a unix socket to be exposed by the control server."""
|
||||||
|
|
||||||
|
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
|
||||||
|
"""In addition to what the base class does, the `socket_path` is expected as a non-optional argument."""
|
||||||
|
from asyncio.streams import open_unix_connection
|
||||||
|
self._open_unix_connection = open_unix_connection
|
||||||
|
self._socket_path = Path(socket_path)
|
||||||
|
super().__init__(**conn_kwargs)
|
||||||
|
|
||||||
|
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
"""
|
||||||
|
Wrapper around the `asyncio.open_unix_connection` function.
|
||||||
|
|
||||||
|
Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
|
||||||
|
otherwise, the stream-reader and -writer tuple is returned.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await self._open_unix_connection(self._socket_path, **kwargs)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("No socket at", self._socket_path, file=sys.stderr)
|
print("No socket at", self._socket_path, file=sys.stderr)
|
||||||
return None, None
|
return None, None
|
||||||
|
@ -20,10 +20,28 @@ Constants used by more than one module in the package.
|
|||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
MSG_BYTES = 1024
|
|
||||||
CMD_START = 'start'
|
DEFAULT_TASK_GROUP = ''
|
||||||
CMD_STOP = 'stop'
|
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
|
||||||
CMD_STOP_ALL = 'stop_all'
|
|
||||||
CMD_NUM_RUNNING = 'num_running'
|
|
||||||
CMD_FUNC = 'func'
|
|
||||||
CLIENT_EXIT = 'exit'
|
CLIENT_EXIT = 'exit'
|
||||||
|
|
||||||
|
SESSION_MSG_BYTES = 1024 * 100
|
||||||
|
SESSION_WRITER = 'session_writer'
|
||||||
|
|
||||||
|
|
||||||
|
class CLIENT_INFO:
|
||||||
|
__slots__ = ()
|
||||||
|
TERMINAL_WIDTH = 'terminal_width'
|
||||||
|
|
||||||
|
|
||||||
|
class CMD:
|
||||||
|
__slots__ = ()
|
||||||
|
CMD = 'command'
|
||||||
|
NAME = 'name'
|
||||||
|
POOL_SIZE = 'pool-size'
|
||||||
|
NUM_RUNNING = 'num-running'
|
||||||
|
START = 'start'
|
||||||
|
STOP = 'stop'
|
||||||
|
STOP_ALL = 'stop-all'
|
||||||
|
FUNC_NAME = 'func-name'
|
||||||
|
@ -23,6 +23,10 @@ class PoolException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PoolIsClosed(PoolException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PoolIsLocked(PoolException):
|
class PoolIsLocked(PoolException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -43,9 +47,29 @@ class InvalidTaskID(PoolException):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidGroupName(PoolException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PoolStillUnlocked(PoolException):
|
class PoolStillUnlocked(PoolException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NotCoroutine(PoolException):
|
class NotCoroutine(PoolException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ServerException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownTaskPoolClass(ServerException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NotATaskPool(ServerException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HelpRequested(ServerException):
|
||||||
|
pass
|
||||||
|
75
src/asyncio_taskpool/group_register.py
Normal file
75
src/asyncio_taskpool/group_register.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
This module contains the definition of the `TaskGroupRegister` class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from asyncio.locks import Lock
|
||||||
|
from collections.abc import MutableSet
|
||||||
|
from typing import Iterator, Set
|
||||||
|
|
||||||
|
|
||||||
|
class TaskGroupRegister(MutableSet):
|
||||||
|
"""
|
||||||
|
This class combines the interface of a regular `set` with that of the `asyncio.Lock`.
|
||||||
|
|
||||||
|
It serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for
|
||||||
|
preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of
|
||||||
|
tasks, as well as before starting a task within the group.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *task_ids: int) -> None:
|
||||||
|
self._ids: Set[int] = set(task_ids)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def __contains__(self, task_id: int) -> bool:
|
||||||
|
"""Abstract method for the `MutableSet` base class."""
|
||||||
|
return task_id in self._ids
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
"""Abstract method for the `MutableSet` base class."""
|
||||||
|
return iter(self._ids)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Abstract method for the `MutableSet` base class."""
|
||||||
|
return len(self._ids)
|
||||||
|
|
||||||
|
def add(self, task_id: int) -> None:
|
||||||
|
"""Abstract method for the `MutableSet` base class."""
|
||||||
|
self._ids.add(task_id)
|
||||||
|
|
||||||
|
def discard(self, task_id: int) -> None:
|
||||||
|
"""Abstract method for the `MutableSet` base class."""
|
||||||
|
self._ids.discard(task_id)
|
||||||
|
|
||||||
|
async def acquire(self) -> bool:
|
||||||
|
"""Wrapper around the lock's `acquire()` method."""
|
||||||
|
return await self._lock.acquire()
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
"""Wrapper around the lock's `release()` method."""
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> None:
|
||||||
|
"""Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
|
||||||
|
await self._lock.acquire()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
"""Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
|
||||||
|
self._lock.release()
|
@ -21,7 +21,8 @@ Miscellaneous helper functions.
|
|||||||
|
|
||||||
from asyncio.coroutines import iscoroutinefunction
|
from asyncio.coroutines import iscoroutinefunction
|
||||||
from asyncio.queues import Queue
|
from asyncio.queues import Queue
|
||||||
from typing import Any, Optional
|
from inspect import getdoc
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
||||||
|
|
||||||
@ -48,3 +49,21 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
|
|||||||
|
|
||||||
async def join_queue(q: Queue) -> None:
|
async def join_queue(q: Queue) -> None:
|
||||||
await q.join()
|
await q.join()
|
||||||
|
|
||||||
|
|
||||||
|
def tasks_str(num: int) -> str:
|
||||||
|
return "tasks" if num != 1 else "task"
|
||||||
|
|
||||||
|
|
||||||
|
def get_first_doc_line(obj: object) -> str:
|
||||||
|
return getdoc(obj).strip().split("\n", 1)[0].strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
|
||||||
|
try:
|
||||||
|
if iscoroutinefunction(_function_to_execute):
|
||||||
|
return await _function_to_execute(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return _function_to_execute(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
File diff suppressed because it is too large
Load Diff
58
src/asyncio_taskpool/queue_context.py
Normal file
58
src/asyncio_taskpool/queue_context.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
This module contains the definition of an `asyncio.Queue` subclass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from asyncio.queues import Queue as _Queue
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class Queue(_Queue):
|
||||||
|
"""This just adds a little syntactic sugar to the `asyncio.Queue`."""
|
||||||
|
|
||||||
|
def item_processed(self) -> None:
|
||||||
|
"""
|
||||||
|
Does exactly the same as `task_done()`.
|
||||||
|
|
||||||
|
This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing,
|
||||||
|
invites confusion, and immensely reduces readability (in the context of this library). And readability counts.
|
||||||
|
"""
|
||||||
|
self.task_done()
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Any:
|
||||||
|
"""
|
||||||
|
Implements an asynchronous context manager for the queue.
|
||||||
|
|
||||||
|
Upon entering `get()` is awaited and subsequently whatever came out of the queue is returned.
|
||||||
|
It allows writing code this way:
|
||||||
|
>>> queue = Queue()
|
||||||
|
>>> ...
|
||||||
|
>>> async with queue as item:
|
||||||
|
>>> ...
|
||||||
|
"""
|
||||||
|
return await self.get()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
"""
|
||||||
|
Implements an asynchronous context manager for the queue.
|
||||||
|
|
||||||
|
Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want,
|
||||||
|
but in some situations it makes the codes much cleaner.
|
||||||
|
"""
|
||||||
|
self.item_processed()
|
@ -23,129 +23,149 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio import AbstractServer
|
from asyncio import AbstractServer
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
from asyncio.streams import StreamReader, StreamWriter, start_server
|
||||||
from asyncio.tasks import Task, create_task
|
from asyncio.tasks import Task, create_task
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union, Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from . import constants
|
from .client import ControlClient, TCPControlClient, UnixControlClient
|
||||||
from .pool import SimpleTaskPool
|
from .pool import TaskPool, SimpleTaskPool
|
||||||
from .client import ControlClient, UnixControlClient
|
from .session import ControlSession
|
||||||
|
from .types import ConnectedCallbackT
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def tasks_str(num: int) -> str:
|
|
||||||
return "tasks" if num != 1 else "task"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
|
|
||||||
cmd = msg.strip().split(' ', 1)
|
|
||||||
if len(cmd) > 1:
|
|
||||||
try:
|
|
||||||
return cmd[0], int(cmd[1])
|
|
||||||
except ValueError:
|
|
||||||
return None, None
|
|
||||||
return cmd[0], None
|
|
||||||
|
|
||||||
|
|
||||||
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
|
||||||
client_class = ControlClient
|
"""
|
||||||
|
Abstract base class for a task pool control server.
|
||||||
|
|
||||||
|
This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
|
||||||
|
connecting to it. The entire interface is defined within that session class.
|
||||||
|
"""
|
||||||
|
_client_class = ControlClient
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@property
|
||||||
|
def client_class_name(cls) -> str:
|
||||||
|
"""Returns the name of the control client class matching the server class."""
|
||||||
|
return cls._client_class.__name__
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||||
|
"""
|
||||||
|
Initializes, starts, and returns an async server instance (Unix or TCP type).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_connected_cb:
|
||||||
|
The callback for when a client connects to the server (as per `asyncio.start_server` or
|
||||||
|
`asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
|
||||||
|
**kwargs (optional):
|
||||||
|
Keyword arguments to pass into the function that starts the server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The running server object (a type of `asyncio.Server`).
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def final_callback(self) -> None:
|
def _final_callback(self) -> None:
|
||||||
|
"""The method to run after the server's `serve_forever` methods ends for whatever reason."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
|
||||||
self._pool: SimpleTaskPool = pool
|
"""
|
||||||
|
Initializes by merely saving the internal attributes, but without starting the server yet.
|
||||||
|
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
|
||||||
|
tied to one specific task pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool:
|
||||||
|
An instance of a `BaseTaskPool` subclass to tie the server to.
|
||||||
|
**server_kwargs (optional):
|
||||||
|
Keyword arguments that will be passed into the function that starts the server.
|
||||||
|
"""
|
||||||
|
self._pool: Union[TaskPool, SimpleTaskPool] = pool
|
||||||
self._server_kwargs = server_kwargs
|
self._server_kwargs = server_kwargs
|
||||||
self._server: Optional[AbstractServer] = None
|
self._server: Optional[AbstractServer] = None
|
||||||
|
|
||||||
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
@property
|
||||||
if num is None:
|
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
|
||||||
num = 1
|
"""Read-only property for accessing the task pool instance controlled by the server."""
|
||||||
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
|
return self._pool
|
||||||
writer.write(str(await self._pool.start(num)).encode())
|
|
||||||
|
|
||||||
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
def is_serving(self) -> bool:
|
||||||
if num is None:
|
"""Wrapper around the `asyncio.Server.is_serving` method."""
|
||||||
num = 1
|
return self._server.is_serving()
|
||||||
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
|
|
||||||
# the requested number may be greater than the total number of running tasks
|
|
||||||
writer.write(str(self._pool.stop(num)).encode())
|
|
||||||
|
|
||||||
def _stop_all_tasks(self, writer: StreamWriter) -> None:
|
|
||||||
log.debug("%s requests stopping all tasks", self.client_class.__name__)
|
|
||||||
writer.write(str(self._pool.stop_all()).encode())
|
|
||||||
|
|
||||||
def _pool_size(self, writer: StreamWriter) -> None:
|
|
||||||
log.debug("%s requests number of running tasks", self.client_class.__name__)
|
|
||||||
writer.write(str(self._pool.num_running).encode())
|
|
||||||
|
|
||||||
def _pool_func(self, writer: StreamWriter) -> None:
|
|
||||||
log.debug("%s requests pool function", self.client_class.__name__)
|
|
||||||
writer.write(self._pool.func_name.encode())
|
|
||||||
|
|
||||||
async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
|
|
||||||
while self._server.is_serving():
|
|
||||||
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
|
|
||||||
if not msg:
|
|
||||||
log.debug("%s disconnected", self.client_class.__name__)
|
|
||||||
break
|
|
||||||
cmd, arg = get_cmd_arg(msg)
|
|
||||||
if cmd == constants.CMD_START:
|
|
||||||
await self._start_tasks(writer, arg)
|
|
||||||
elif cmd == constants.CMD_STOP:
|
|
||||||
self._stop_tasks(writer, arg)
|
|
||||||
elif cmd == constants.CMD_STOP_ALL:
|
|
||||||
self._stop_all_tasks(writer)
|
|
||||||
elif cmd == constants.CMD_NUM_RUNNING:
|
|
||||||
self._pool_size(writer)
|
|
||||||
elif cmd == constants.CMD_FUNC:
|
|
||||||
self._pool_func(writer)
|
|
||||||
else:
|
|
||||||
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
|
|
||||||
writer.write(b"Invalid command!")
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
log.debug("%s connected", self.client_class.__name__)
|
"""
|
||||||
writer.write(str(self._pool).encode())
|
The universal client callback that will be passed into the `_get_server_instance` method.
|
||||||
await writer.drain()
|
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
|
||||||
await self._listen(reader, writer)
|
"""
|
||||||
|
session = ControlSession(self, reader, writer)
|
||||||
|
await session.client_handshake()
|
||||||
|
await session.listen()
|
||||||
|
|
||||||
async def _serve_forever(self) -> None:
|
async def _serve_forever(self) -> None:
|
||||||
|
"""
|
||||||
|
To be run as an `asyncio.Task` by the following method.
|
||||||
|
Serves as a wrapper around the the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
|
||||||
|
method is called, when the former method ends for whatever reason.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
async with self._server:
|
async with self._server:
|
||||||
await self._server.serve_forever()
|
await self._server.serve_forever()
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
log.debug("%s stopped", self.__class__.__name__)
|
log.debug("%s stopped", self.__class__.__name__)
|
||||||
finally:
|
finally:
|
||||||
self.final_callback()
|
self._final_callback()
|
||||||
|
|
||||||
async def serve_forever(self) -> Task:
|
async def serve_forever(self) -> Task:
|
||||||
|
"""
|
||||||
|
This method actually starts the server and begins listening to client connections on the specified interface.
|
||||||
|
It should never block because the serving will be performed in a separate task.
|
||||||
|
"""
|
||||||
log.debug("Starting %s...", self.__class__.__name__)
|
log.debug("Starting %s...", self.__class__.__name__)
|
||||||
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
||||||
return create_task(self._serve_forever())
|
return create_task(self._serve_forever())
|
||||||
|
|
||||||
|
|
||||||
class UnixControlServer(ControlServer):
|
class TCPControlServer(ControlServer):
|
||||||
client_class = UnixControlClient
|
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
|
||||||
|
_client_class = TCPControlClient
|
||||||
|
|
||||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||||
|
self._host = server_kwargs.pop('host')
|
||||||
|
self._port = server_kwargs.pop('port')
|
||||||
|
super().__init__(pool, **server_kwargs)
|
||||||
|
|
||||||
|
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||||
|
server = await start_server(client_connected_cb, self._host, self._port, **kwargs)
|
||||||
|
log.debug("Opened socket at %s:%s", self._host, self._port)
|
||||||
|
return server
|
||||||
|
|
||||||
|
def _final_callback(self) -> None:
|
||||||
|
log.debug("Closed socket at %s:%s", self._host, self._port)
|
||||||
|
|
||||||
|
|
||||||
|
class UnixControlServer(ControlServer):
|
||||||
|
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
|
||||||
|
_client_class = UnixControlClient
|
||||||
|
|
||||||
|
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||||
|
from asyncio.streams import start_unix_server
|
||||||
|
self._start_unix_server = start_unix_server
|
||||||
self._socket_path = Path(server_kwargs.pop('path'))
|
self._socket_path = Path(server_kwargs.pop('path'))
|
||||||
super().__init__(pool, **server_kwargs)
|
super().__init__(pool, **server_kwargs)
|
||||||
|
|
||||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||||
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
server = await self._start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||||
log.debug("Opened socket '%s'", str(self._socket_path))
|
log.debug("Opened socket '%s'", str(self._socket_path))
|
||||||
return srv
|
return server
|
||||||
|
|
||||||
def final_callback(self) -> None:
|
def _final_callback(self) -> None:
|
||||||
|
"""Removes the unix socket on which the server was listening."""
|
||||||
self._socket_path.unlink()
|
self._socket_path.unlink()
|
||||||
log.debug("Removed socket '%s'", str(self._socket_path))
|
log.debug("Removed socket '%s'", str(self._socket_path))
|
||||||
|
304
src/asyncio_taskpool/session.py
Normal file
304
src/asyncio_taskpool/session.py
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
This module contains the the definition of the control session class used by the control server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from argparse import ArgumentError, HelpFormatter
|
||||||
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
|
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
|
||||||
|
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
|
||||||
|
from .helpers import get_first_doc_line, return_or_exception, tasks_str
|
||||||
|
from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
|
||||||
|
from .session_parser import CommandParser, NUM
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .server import ControlServer
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ControlSession:
|
||||||
|
"""
|
||||||
|
This class defines the API for controlling a task pool instance from the outside.
|
||||||
|
|
||||||
|
The commands received from a connected client are translated into method calls on the task pool instance.
|
||||||
|
A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
"""
|
||||||
|
Instantiation should happen once a client connection to the control server has already been established.
|
||||||
|
|
||||||
|
For more convenient/efficient access, some of the server's properties are saved in separate attributes.
|
||||||
|
The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
|
||||||
|
initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server:
|
||||||
|
The instance of a `ControlServer` subclass starting the session.
|
||||||
|
reader:
|
||||||
|
The `asyncio.StreamReader` created when a client connected to the server.
|
||||||
|
writer:
|
||||||
|
The `asyncio.StreamWriter` created when a client connected to the server.
|
||||||
|
"""
|
||||||
|
self._control_server: 'ControlServer' = server
|
||||||
|
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
|
||||||
|
self._client_class_name = server.client_class_name
|
||||||
|
self._reader: StreamReader = reader
|
||||||
|
self._writer: StreamWriter = writer
|
||||||
|
self._parser: Optional[CommandParser] = None
|
||||||
|
self._subparsers = None
|
||||||
|
|
||||||
|
def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
|
||||||
|
**kwargs) -> CommandParser:
|
||||||
|
"""
|
||||||
|
Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.
|
||||||
|
|
||||||
|
Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name:
|
||||||
|
The command name; passed directly into the `add_parser` method.
|
||||||
|
prog (optional):
|
||||||
|
Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
|
||||||
|
equal to the `name` argument.
|
||||||
|
short_help (optional):
|
||||||
|
Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
|
||||||
|
`long_help` argument is present; in that case the `long_help` argument is passed as `help`.
|
||||||
|
long_help (optional):
|
||||||
|
Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
|
||||||
|
the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
|
||||||
|
**kwargs (optional):
|
||||||
|
Any keyword-arguments to directly pass into the `add_parser` method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of the `CommandParser` class representing the newly added control command.
|
||||||
|
"""
|
||||||
|
if prog is None:
|
||||||
|
prog = name
|
||||||
|
kwargs.setdefault('help', short_help or long_help)
|
||||||
|
kwargs.setdefault('description', long_help or short_help)
|
||||||
|
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
|
||||||
|
|
||||||
|
def _add_base_commands(self) -> None:
|
||||||
|
"""
|
||||||
|
Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.
|
||||||
|
|
||||||
|
These include commands mapping to the following pool methods:
|
||||||
|
- __str__
|
||||||
|
- pool_size (get/set property)
|
||||||
|
- num_running
|
||||||
|
"""
|
||||||
|
self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
|
||||||
|
self._add_command(
|
||||||
|
CMD.POOL_SIZE,
|
||||||
|
short_help="Get/set the maximum number of tasks in the pool.",
|
||||||
|
formatter_class=HelpFormatter
|
||||||
|
).add_optional_num_argument(
|
||||||
|
default=None,
|
||||||
|
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
|
||||||
|
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
|
||||||
|
)
|
||||||
|
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
|
||||||
|
|
||||||
|
def _add_simple_commands(self) -> None:
|
||||||
|
"""
|
||||||
|
Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.
|
||||||
|
|
||||||
|
These include commands mapping to the following pool methods:
|
||||||
|
- start
|
||||||
|
- stop
|
||||||
|
- stop_all
|
||||||
|
- func_name
|
||||||
|
"""
|
||||||
|
self._add_command(
|
||||||
|
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
|
||||||
|
).add_optional_num_argument(
|
||||||
|
help="Number of tasks to start."
|
||||||
|
)
|
||||||
|
self._add_command(
|
||||||
|
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
|
||||||
|
).add_optional_num_argument(
|
||||||
|
help="Number of tasks to stop."
|
||||||
|
)
|
||||||
|
self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
|
||||||
|
self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))
|
||||||
|
|
||||||
|
def _add_advanced_commands(self) -> None:
|
||||||
|
"""
|
||||||
|
Adds the commands that are only supported, if a `TaskPool` object is controlled.
|
||||||
|
|
||||||
|
These include commands mapping to the following pool methods:
|
||||||
|
- ...
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _init_parser(self, client_terminal_width: int) -> None:
|
||||||
|
"""
|
||||||
|
Initializes and fully configures the `CommandParser` responsible for handling the input.
|
||||||
|
|
||||||
|
Depending on what specific task pool class is controlled by the server, different commands are added.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_terminal_width:
|
||||||
|
The number of columns of the client's terminal to be able to nicely format messages from the parser.
|
||||||
|
"""
|
||||||
|
parser_kwargs = {
|
||||||
|
'prog': '',
|
||||||
|
SESSION_WRITER: self._writer,
|
||||||
|
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
|
||||||
|
}
|
||||||
|
self._parser = CommandParser(**parser_kwargs)
|
||||||
|
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
|
||||||
|
self._add_base_commands()
|
||||||
|
if isinstance(self._pool, TaskPool):
|
||||||
|
self._add_advanced_commands()
|
||||||
|
elif isinstance(self._pool, SimpleTaskPool):
|
||||||
|
self._add_simple_commands()
|
||||||
|
elif isinstance(self._pool, BaseTaskPool):
|
||||||
|
raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}")
|
||||||
|
else:
|
||||||
|
raise NotATaskPool(f"Not a task pool instance: {self._pool}")
|
||||||
|
|
||||||
|
async def client_handshake(self) -> None:
|
||||||
|
"""
|
||||||
|
This method must be invoked before starting any other client interaction.
|
||||||
|
|
||||||
|
Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
|
||||||
|
"""
|
||||||
|
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
|
||||||
|
log.debug("%s connected", self._client_class_name)
|
||||||
|
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
|
||||||
|
self._writer.write(str(self._pool).encode())
|
||||||
|
await self._writer.drain()
|
||||||
|
|
||||||
|
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Acts as a wrapper around a call to a specific task pool method.
|
||||||
|
|
||||||
|
The method is called and any exception is caught and saved. If there is no output and no exception caught, a
|
||||||
|
generic confirmation message is sent back to the client. Otherwise the output or a string representation of
|
||||||
|
the exception caught is sent back.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func:
|
||||||
|
Reference to the task pool method.
|
||||||
|
*args (optional):
|
||||||
|
Any positional arguments to call the method with.
|
||||||
|
*+kwargs (optional):
|
||||||
|
Any keyword-arguments to call the method with.
|
||||||
|
"""
|
||||||
|
output = await return_or_exception(func, *args, **kwargs)
|
||||||
|
self._writer.write(b"ok" if output is None else str(output).encode())
|
||||||
|
|
||||||
|
async def _cmd_name(self, **_kwargs) -> None:
|
||||||
|
"""Maps to the `__str__` method of any task pool class."""
|
||||||
|
log.debug("%s requests task pool name", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.__str__, self._pool)
|
||||||
|
|
||||||
|
async def _cmd_pool_size(self, **kwargs) -> None:
|
||||||
|
"""Maps to the `pool_size` property of any task pool class."""
|
||||||
|
num = kwargs.get(NUM)
|
||||||
|
if num is None:
|
||||||
|
log.debug("%s requests pool size", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool)
|
||||||
|
else:
|
||||||
|
log.debug("%s requests setting pool size to %s", self._client_class_name, num)
|
||||||
|
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
|
||||||
|
|
||||||
|
async def _cmd_num_running(self, **_kwargs) -> None:
|
||||||
|
"""Maps to the `num_running` property of any task pool class."""
|
||||||
|
log.debug("%s requests number of running tasks", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
|
||||||
|
|
||||||
|
async def _cmd_start(self, **kwargs) -> None:
|
||||||
|
"""Maps to the `start` method of the `SimpleTaskPool` class."""
|
||||||
|
num = kwargs[NUM]
|
||||||
|
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
|
||||||
|
await self._write_function_output(self._pool.start, num)
|
||||||
|
|
||||||
|
async def _cmd_stop(self, **kwargs) -> None:
|
||||||
|
"""Maps to the `stop` method of the `SimpleTaskPool` class."""
|
||||||
|
num = kwargs[NUM]
|
||||||
|
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
|
||||||
|
await self._write_function_output(self._pool.stop, num)
|
||||||
|
|
||||||
|
async def _cmd_stop_all(self, **_kwargs) -> None:
|
||||||
|
"""Maps to the `stop_all` method of the `SimpleTaskPool` class."""
|
||||||
|
log.debug("%s requests stopping all tasks", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.stop_all)
|
||||||
|
|
||||||
|
async def _cmd_func_name(self, **_kwargs) -> None:
|
||||||
|
"""Maps to the `func_name` method of the `SimpleTaskPool` class."""
|
||||||
|
log.debug("%s requests pool function name", self._client_class_name)
|
||||||
|
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
|
||||||
|
|
||||||
|
async def _execute_command(self, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs:
|
||||||
|
Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
|
||||||
|
simply passed into the method determined from the command name.
|
||||||
|
"""
|
||||||
|
method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
|
||||||
|
await method(**kwargs)
|
||||||
|
|
||||||
|
async def _parse_command(self, msg: str) -> None:
|
||||||
|
"""
|
||||||
|
Takes a message from the client and attempts to parse it.
|
||||||
|
|
||||||
|
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
|
||||||
|
`CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
|
||||||
|
dictionary of keyword-arguments returned by the `CommandParser` passed into it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg:
|
||||||
|
The non-empty string read from the client stream.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kwargs = vars(self._parser.parse_args(msg.split(' ')))
|
||||||
|
except ArgumentError as e:
|
||||||
|
self._writer.write(str(e).encode())
|
||||||
|
return
|
||||||
|
except HelpRequested:
|
||||||
|
return
|
||||||
|
await self._execute_command(**kwargs)
|
||||||
|
|
||||||
|
async def listen(self) -> None:
|
||||||
|
"""
|
||||||
|
Enters the main control loop that only ends if either the server or the client disconnect.
|
||||||
|
|
||||||
|
Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
|
||||||
|
This method should be called, when the client connection was established and the handshake was successful.
|
||||||
|
It will obviously block indefinitely.
|
||||||
|
"""
|
||||||
|
while self._control_server.is_serving():
|
||||||
|
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
|
||||||
|
if not msg:
|
||||||
|
log.debug("%s disconnected", self._client_class_name)
|
||||||
|
break
|
||||||
|
await self._parse_command(msg)
|
||||||
|
await self._writer.drain()
|
127
src/asyncio_taskpool/session_parser.py
Normal file
127
src/asyncio_taskpool/session_parser.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
This module contains the the definition of the `CommandParser` class used in a control server session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter
|
||||||
|
from asyncio.streams import StreamWriter
|
||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
from .constants import SESSION_WRITER, CLIENT_INFO
|
||||||
|
from .exceptions import HelpRequested
|
||||||
|
|
||||||
|
|
||||||
|
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
|
||||||
|
FORMATTER_CLASS = 'formatter_class'
|
||||||
|
NUM = 'num'
|
||||||
|
|
||||||
|
|
||||||
|
class CommandParser(ArgumentParser):
|
||||||
|
"""
|
||||||
|
Subclass of the standard `argparse.ArgumentParser` for remote interaction.
|
||||||
|
|
||||||
|
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
|
||||||
|
instance passed to it during initialization.
|
||||||
|
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
|
||||||
|
connected client.
|
||||||
|
Finally, it offers some convenience methods and makes use of custom exceptions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
|
||||||
|
"""
|
||||||
|
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument.
|
||||||
|
|
||||||
|
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not
|
||||||
|
as convenient, when making use of sub-parsers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
terminal_width:
|
||||||
|
The number of columns of the terminal to which to adjust help formatting.
|
||||||
|
base_cls (optional):
|
||||||
|
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
|
||||||
|
"""
|
||||||
|
if base_cls is None:
|
||||||
|
base_cls = ArgumentDefaultsHelpFormatter
|
||||||
|
|
||||||
|
class ClientHelpFormatter(base_cls):
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
kwargs['width'] = terminal_width
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
return ClientHelpFormatter
|
||||||
|
|
||||||
|
def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Sets additional internal attributes depending on whether a parent-parser was defined.
|
||||||
|
|
||||||
|
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
|
||||||
|
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent (optional):
|
||||||
|
An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method
|
||||||
|
of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present,
|
||||||
|
the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`.
|
||||||
|
**kwargs(optional):
|
||||||
|
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
|
||||||
|
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
|
||||||
|
is empty.
|
||||||
|
"""
|
||||||
|
self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER)
|
||||||
|
self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH)
|
||||||
|
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
|
||||||
|
kwargs.setdefault('exit_on_error', False)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_writer(self) -> StreamWriter:
|
||||||
|
"""Returns the predefined stream writer object of the control session."""
|
||||||
|
return self._session_writer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def terminal_width(self) -> int:
|
||||||
|
"""Returns the predefined terminal width."""
|
||||||
|
return self._terminal_width
|
||||||
|
|
||||||
|
def _print_message(self, message: str, *args, **kwargs) -> None:
|
||||||
|
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
|
||||||
|
if message:
|
||||||
|
self._session_writer.write(message.encode())
|
||||||
|
|
||||||
|
def exit(self, status: int = 0, message: str = None) -> None:
|
||||||
|
"""This is overridden to prevent system exit to be invoked."""
|
||||||
|
if message:
|
||||||
|
self._print_message(message)
|
||||||
|
|
||||||
|
def print_help(self, file=None) -> None:
|
||||||
|
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
|
||||||
|
super().print_help(file)
|
||||||
|
raise HelpRequested
|
||||||
|
|
||||||
|
def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action:
|
||||||
|
"""Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified."""
|
||||||
|
if not name_or_flags:
|
||||||
|
name_or_flags = (NUM, )
|
||||||
|
kwargs.setdefault('nargs', '?')
|
||||||
|
kwargs.setdefault('default', 1)
|
||||||
|
kwargs.setdefault('type', int)
|
||||||
|
return self.add_argument(*name_or_flags, **kwargs)
|
@ -20,6 +20,7 @@ Custom type definitions used in various modules.
|
|||||||
|
|
||||||
|
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
|
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
|
||||||
@ -28,10 +29,13 @@ T = TypeVar('T')
|
|||||||
ArgsT = Iterable[Any]
|
ArgsT = Iterable[Any]
|
||||||
KwArgsT = Mapping[str, Any]
|
KwArgsT = Mapping[str, Any]
|
||||||
|
|
||||||
AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
|
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
|
||||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||||
|
|
||||||
EndCallbackT = Callable
|
EndCB = Callable
|
||||||
CancelCallbackT = Callable
|
CancelCB = Callable
|
||||||
|
|
||||||
|
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||||
|
|
||||||
|
PathT = Union[Path, str]
|
||||||
|
209
tests/test_client.py
Normal file
209
tests/test_client.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Unittests for the `asyncio_taskpool.client` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import IsolatedAsyncioTestCase, skipIf
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool import client
|
||||||
|
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
|
||||||
|
|
||||||
|
|
||||||
|
FOO, BAR = 'foo', 'bar'
|
||||||
|
|
||||||
|
|
||||||
|
class ControlClientTestCase(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set())
|
||||||
|
self.print_patcher = patch.object(client, 'print')
|
||||||
|
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||||
|
self.mock_print = self.print_patcher.start()
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.client = client.ControlClient(**self.kwargs)
|
||||||
|
|
||||||
|
self.mock_read = AsyncMock(return_value=FOO.encode())
|
||||||
|
self.mock_write, self.mock_drain = MagicMock(), AsyncMock()
|
||||||
|
self.mock_reader = MagicMock(read=self.mock_read)
|
||||||
|
self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.abstract_patcher.stop()
|
||||||
|
self.print_patcher.stop()
|
||||||
|
|
||||||
|
def test_client_info(self):
|
||||||
|
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
|
||||||
|
client.ControlClient.client_info())
|
||||||
|
|
||||||
|
async def test_abstract(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
await self.client._open_connection(**self.kwargs)
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(self.kwargs, self.client._conn_kwargs)
|
||||||
|
self.assertFalse(self.client._connected)
|
||||||
|
|
||||||
|
@patch.object(client.ControlClient, 'client_info')
|
||||||
|
async def test__server_handshake(self, mock_client_info: MagicMock):
|
||||||
|
mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
|
||||||
|
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
mock_client_info.assert_called_once_with()
|
||||||
|
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
||||||
|
self.mock_drain.assert_awaited_once_with()
|
||||||
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
|
self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())
|
||||||
|
|
||||||
|
@patch.object(client, 'input')
|
||||||
|
def test__get_command(self, mock_input: MagicMock):
|
||||||
|
self.client._connected = True
|
||||||
|
|
||||||
|
mock_input.return_value = ' ' + FOO.upper() + ' '
|
||||||
|
mock_close = MagicMock()
|
||||||
|
mock_writer = MagicMock(close=mock_close)
|
||||||
|
output = self.client._get_command(mock_writer)
|
||||||
|
self.assertEqual(FOO, output)
|
||||||
|
mock_input.assert_called_once()
|
||||||
|
mock_close.assert_not_called()
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
mock_input.reset_mock()
|
||||||
|
mock_input.side_effect = KeyboardInterrupt
|
||||||
|
self.assertIsNone(self.client._get_command(mock_writer))
|
||||||
|
mock_input.assert_called_once()
|
||||||
|
mock_close.assert_not_called()
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
mock_input.reset_mock()
|
||||||
|
mock_input.side_effect = EOFError
|
||||||
|
self.assertIsNone(self.client._get_command(mock_writer))
|
||||||
|
mock_input.assert_called_once()
|
||||||
|
mock_close.assert_called_once()
|
||||||
|
self.assertFalse(self.client._connected)
|
||||||
|
|
||||||
|
@patch.object(client.ControlClient, '_get_command')
|
||||||
|
async def test__interact(self, mock__get_command: MagicMock):
|
||||||
|
self.client._connected = True
|
||||||
|
|
||||||
|
mock__get_command.return_value = None
|
||||||
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
|
self.mock_write.assert_not_called()
|
||||||
|
self.mock_drain.assert_not_awaited()
|
||||||
|
self.mock_read.assert_not_awaited()
|
||||||
|
self.mock_print.assert_not_called()
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
||||||
|
self.mock_drain.side_effect = err = ConnectionError()
|
||||||
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
|
self.mock_write.assert_called_once_with(cmd.encode())
|
||||||
|
self.mock_drain.assert_awaited_once_with()
|
||||||
|
self.mock_read.assert_not_awaited()
|
||||||
|
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||||
|
self.assertFalse(self.client._connected)
|
||||||
|
|
||||||
|
self.client._connected = True
|
||||||
|
self.mock_write.reset_mock()
|
||||||
|
self.mock_drain.reset_mock(side_effect=True)
|
||||||
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
|
self.mock_write.assert_called_once_with(cmd.encode())
|
||||||
|
self.mock_drain.assert_awaited_once_with()
|
||||||
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
|
self.mock_print.assert_called_once_with(FOO)
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
@patch.object(client.ControlClient, '_interact')
|
||||||
|
@patch.object(client.ControlClient, '_server_handshake')
|
||||||
|
@patch.object(client.ControlClient, '_open_connection')
|
||||||
|
async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock,
|
||||||
|
mock__interact: AsyncMock):
|
||||||
|
mock__open_connection.return_value = None, None
|
||||||
|
self.assertIsNone(await self.client.start())
|
||||||
|
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||||
|
mock__server_handshake.assert_not_awaited()
|
||||||
|
mock__interact.assert_not_awaited()
|
||||||
|
self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr)
|
||||||
|
|
||||||
|
mock__open_connection.reset_mock()
|
||||||
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
|
mock__open_connection.return_value = self.mock_reader, self.mock_writer
|
||||||
|
self.assertIsNone(await self.client.start())
|
||||||
|
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||||
|
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||||
|
mock__interact.assert_not_awaited()
|
||||||
|
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||||
|
|
||||||
|
mock__open_connection.reset_mock()
|
||||||
|
mock__server_handshake.reset_mock()
|
||||||
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
|
self.client._connected = True
|
||||||
|
def disconnect(*_args, **_kwargs) -> None: self.client._connected = False
|
||||||
|
mock__interact.side_effect = disconnect
|
||||||
|
self.assertIsNone(await self.client.start())
|
||||||
|
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||||
|
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||||
|
mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||||
|
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||||
|
|
||||||
|
|
||||||
|
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
|
||||||
|
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
|
||||||
|
self.mock_base_init = self.base_init_patcher.start()
|
||||||
|
self.path = '/tmp/asyncio_taskpool'
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.base_init_patcher.stop()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(Path(self.path), self.client._socket_path)
|
||||||
|
self.mock_base_init.assert_called_once_with(**self.kwargs)
|
||||||
|
|
||||||
|
@patch.object(client, 'print')
|
||||||
|
@patch.object(client, 'open_unix_connection')
|
||||||
|
async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock):
|
||||||
|
mock_open_unix_connection.return_value = expected_output = 'something'
|
||||||
|
kwargs = {'a': 1, 'b': 2}
|
||||||
|
output = await self.client._open_connection(**kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||||
|
mock_print.assert_not_called()
|
||||||
|
|
||||||
|
mock_open_unix_connection.reset_mock()
|
||||||
|
|
||||||
|
mock_open_unix_connection.side_effect = FileNotFoundError
|
||||||
|
output1, output2 = await self.client._open_connection(**kwargs)
|
||||||
|
self.assertIsNone(output1)
|
||||||
|
self.assertIsNone(output2)
|
||||||
|
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||||
|
mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)
|
@ -86,3 +86,43 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_queue = MagicMock(join=mock_join)
|
mock_queue = MagicMock(join=mock_join)
|
||||||
self.assertIsNone(await helpers.join_queue(mock_queue))
|
self.assertIsNone(await helpers.join_queue(mock_queue))
|
||||||
mock_join.assert_awaited_once_with()
|
mock_join.assert_awaited_once_with()
|
||||||
|
|
||||||
|
def test_task_str(self):
|
||||||
|
self.assertEqual("task", helpers.tasks_str(1))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(0))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(-1))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(2))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(-10))
|
||||||
|
self.assertEqual("tasks", helpers.tasks_str(42))
|
||||||
|
|
||||||
|
def test_get_first_doc_line(self):
|
||||||
|
expected_output = 'foo bar baz'
|
||||||
|
mock_obj = MagicMock(__doc__=f"""{expected_output}
|
||||||
|
something else
|
||||||
|
|
||||||
|
even more
|
||||||
|
""")
|
||||||
|
output = helpers.get_first_doc_line(mock_obj)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
async def test_return_or_exception(self):
|
||||||
|
expected_output = '420'
|
||||||
|
mock_func = AsyncMock(return_value=expected_output)
|
||||||
|
args = (1, 3, 5)
|
||||||
|
kwargs = {'a': 1, 'b': 2, 'c': 'foo'}
|
||||||
|
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_func.assert_awaited_once_with(*args, **kwargs)
|
||||||
|
|
||||||
|
mock_func = MagicMock(return_value=expected_output)
|
||||||
|
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_func.assert_called_once_with(*args, **kwargs)
|
||||||
|
|
||||||
|
class TestException(Exception):
|
||||||
|
pass
|
||||||
|
test_exception = TestException()
|
||||||
|
mock_func = MagicMock(side_effect=test_exception)
|
||||||
|
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
|
||||||
|
self.assertEqual(test_exception, output)
|
||||||
|
mock_func.assert_called_once_with(*args, **kwargs)
|
||||||
|
@ -18,19 +18,20 @@ __doc__ = """
|
|||||||
Unittests for the `asyncio_taskpool.pool` module.
|
Unittests for the `asyncio_taskpool.pool` module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.queues import Queue
|
from asyncio.locks import Semaphore
|
||||||
|
from asyncio.queues import QueueEmpty
|
||||||
|
from datetime import datetime
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from asyncio_taskpool import pool, exceptions
|
from asyncio_taskpool import pool, exceptions
|
||||||
|
from asyncio_taskpool.constants import DATETIME_FORMAT
|
||||||
|
|
||||||
|
|
||||||
EMPTY_LIST, EMPTY_DICT = [], {}
|
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
|
||||||
FOO, BAR = 'foo', 'bar'
|
FOO, BAR, BAZ = 'foo', 'bar', 'baz'
|
||||||
|
|
||||||
|
|
||||||
class TestException(Exception):
|
class TestException(Exception):
|
||||||
@ -45,19 +46,12 @@ class CommonTestCase(IsolatedAsyncioTestCase):
|
|||||||
task_pool: pool.BaseTaskPool
|
task_pool: pool.BaseTaskPool
|
||||||
log_lvl: int
|
log_lvl: int
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls) -> None:
|
|
||||||
cls.log_lvl = pool.log.level
|
|
||||||
pool.log.setLevel(999)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls) -> None:
|
|
||||||
pool.log.setLevel(cls.log_lvl)
|
|
||||||
|
|
||||||
def get_task_pool_init_params(self) -> dict:
|
def get_task_pool_init_params(self) -> dict:
|
||||||
return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
|
return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
self.log_lvl = pool.log.level
|
||||||
|
pool.log.setLevel(999)
|
||||||
self._pools = self.TEST_CLASS._pools
|
self._pools = self.TEST_CLASS._pools
|
||||||
# These three methods are called during initialization, so we mock them by default during setup:
|
# These three methods are called during initialization, so we mock them by default during setup:
|
||||||
self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
|
self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
|
||||||
@ -76,6 +70,7 @@ class CommonTestCase(IsolatedAsyncioTestCase):
|
|||||||
self._add_pool_patcher.stop()
|
self._add_pool_patcher.stop()
|
||||||
self.pool_size_patcher.stop()
|
self.pool_size_patcher.stop()
|
||||||
self.dunder_str_patcher.stop()
|
self.dunder_str_patcher.stop()
|
||||||
|
pool.log.setLevel(self.log_lvl)
|
||||||
|
|
||||||
|
|
||||||
class BaseTaskPoolTestCase(CommonTestCase):
|
class BaseTaskPoolTestCase(CommonTestCase):
|
||||||
@ -88,19 +83,23 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
|
self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
|
self.assertEqual(0, self.task_pool._num_started)
|
||||||
|
self.assertEqual(0, self.task_pool._num_cancellations)
|
||||||
|
|
||||||
self.assertFalse(self.task_pool._locked)
|
self.assertFalse(self.task_pool._locked)
|
||||||
self.assertEqual(0, self.task_pool._counter)
|
self.assertFalse(self.task_pool._closed)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
|
|
||||||
self.assertEqual(0, self.task_pool._num_cancelled)
|
|
||||||
self.assertEqual(0, self.task_pool._num_ended)
|
|
||||||
self.assertEqual(self.mock_idx, self.task_pool._idx)
|
|
||||||
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
||||||
|
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
|
|
||||||
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
|
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
|
||||||
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
|
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
|
||||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||||
|
|
||||||
|
self.assertEqual(self.mock_idx, self.task_pool._idx)
|
||||||
|
|
||||||
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
||||||
self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
|
self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
|
||||||
self.mock___str__.assert_called_once_with()
|
self.mock___str__.assert_called_once_with()
|
||||||
@ -143,26 +142,56 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertFalse(self.task_pool._locked)
|
self.assertFalse(self.task_pool._locked)
|
||||||
|
|
||||||
def test_num_running(self):
|
def test_num_running(self):
|
||||||
self.task_pool._running = ['foo', 'bar', 'baz']
|
self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ}
|
||||||
self.assertEqual(3, self.task_pool.num_running)
|
self.assertEqual(3, self.task_pool.num_running)
|
||||||
|
|
||||||
def test_num_cancelled(self):
|
def test_num_cancellations(self):
|
||||||
self.task_pool._num_cancelled = 3
|
self.task_pool._num_cancellations = 3
|
||||||
self.assertEqual(3, self.task_pool.num_cancelled)
|
self.assertEqual(3, self.task_pool.num_cancellations)
|
||||||
|
|
||||||
def test_num_ended(self):
|
def test_num_ended(self):
|
||||||
self.task_pool._num_ended = 3
|
self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ}
|
||||||
self.assertEqual(3, self.task_pool.num_ended)
|
self.assertEqual(3, self.task_pool.num_ended)
|
||||||
|
|
||||||
def test_num_finished(self):
|
def test_num_finished(self):
|
||||||
self.task_pool._num_cancelled = cancelled = 69
|
self.task_pool._num_cancellations = num_cancellations = 69
|
||||||
self.task_pool._num_ended = ended = 420
|
num_ended = 420
|
||||||
self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
|
self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)}
|
||||||
self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)
|
self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ}
|
||||||
|
self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished)
|
||||||
|
|
||||||
def test_is_full(self):
|
def test_is_full(self):
|
||||||
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
|
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
|
||||||
|
|
||||||
|
def test_get_task_group_ids(self):
|
||||||
|
group_name, ids = 'abcdef', [1, 2, 3]
|
||||||
|
self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
|
||||||
|
self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name))
|
||||||
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
|
self.task_pool.get_task_group_ids('something else')
|
||||||
|
|
||||||
|
async def test__check_start(self):
|
||||||
|
self.task_pool._closed = True
|
||||||
|
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
||||||
|
try:
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
self.task_pool._check_start(awaitable=None, function=None)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
self.task_pool._check_start(awaitable=mock_coroutine, function=mock_coroutine_function)
|
||||||
|
with self.assertRaises(exceptions.NotCoroutine):
|
||||||
|
self.task_pool._check_start(awaitable=mock_coroutine_function, function=None)
|
||||||
|
with self.assertRaises(exceptions.NotCoroutine):
|
||||||
|
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
||||||
|
with self.assertRaises(exceptions.PoolIsClosed):
|
||||||
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
||||||
|
self.task_pool._closed = False
|
||||||
|
self.task_pool._locked = True
|
||||||
|
with self.assertRaises(exceptions.PoolIsLocked):
|
||||||
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
||||||
|
self.assertIsNone(self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=True))
|
||||||
|
finally:
|
||||||
|
await mock_coroutine
|
||||||
|
|
||||||
def test__task_name(self):
|
def test__task_name(self):
|
||||||
i = 123
|
i = 123
|
||||||
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
|
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
|
||||||
@ -171,12 +200,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||||
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
|
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
|
||||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||||
self.task_pool._num_cancelled = cancelled = 3
|
self.task_pool._num_cancellations = cancelled = 3
|
||||||
self.task_pool._running[task_id] = mock_task
|
self.task_pool._tasks_running[task_id] = mock_task
|
||||||
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
|
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
|
||||||
self.assertNotIn(task_id, self.task_pool._running)
|
self.assertNotIn(task_id, self.task_pool._tasks_running)
|
||||||
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
|
self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id])
|
||||||
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
|
self.assertEqual(cancelled + 1, self.task_pool._num_cancellations)
|
||||||
mock__task_name.assert_called_with(task_id)
|
mock__task_name.assert_called_with(task_id)
|
||||||
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||||
|
|
||||||
@ -184,15 +213,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||||
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
|
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
|
||||||
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
|
||||||
self.task_pool._num_ended = ended = 3
|
|
||||||
self.task_pool._enough_room._value = room = 123
|
self.task_pool._enough_room._value = room = 123
|
||||||
|
|
||||||
# End running task:
|
# End running task:
|
||||||
self.task_pool._running[task_id] = mock_task
|
self.task_pool._tasks_running[task_id] = mock_task
|
||||||
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
||||||
self.assertNotIn(task_id, self.task_pool._running)
|
self.assertNotIn(task_id, self.task_pool._tasks_running)
|
||||||
self.assertEqual(mock_task, self.task_pool._ended[task_id])
|
self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
|
||||||
self.assertEqual(ended + 1, self.task_pool._num_ended)
|
|
||||||
self.assertEqual(room + 1, self.task_pool._enough_room._value)
|
self.assertEqual(room + 1, self.task_pool._enough_room._value)
|
||||||
mock__task_name.assert_called_with(task_id)
|
mock__task_name.assert_called_with(task_id)
|
||||||
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||||
@ -200,11 +227,10 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
mock_execute_optional.reset_mock()
|
mock_execute_optional.reset_mock()
|
||||||
|
|
||||||
# End cancelled task:
|
# End cancelled task:
|
||||||
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
|
self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_ended.pop(task_id)
|
||||||
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
|
||||||
self.assertNotIn(task_id, self.task_pool._cancelled)
|
self.assertNotIn(task_id, self.task_pool._tasks_cancelled)
|
||||||
self.assertEqual(mock_task, self.task_pool._ended[task_id])
|
self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
|
||||||
self.assertEqual(ended + 2, self.task_pool._num_ended)
|
|
||||||
self.assertEqual(room + 2, self.task_pool._enough_room._value)
|
self.assertEqual(room + 2, self.task_pool._enough_room._value)
|
||||||
mock__task_name.assert_called_with(task_id)
|
mock__task_name.assert_called_with(task_id)
|
||||||
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
|
||||||
@ -246,92 +272,52 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
|
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
|
||||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||||
async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock,
|
@patch.object(pool, 'TaskGroupRegister')
|
||||||
mock_create_task: MagicMock):
|
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||||
def reset_mocks() -> None:
|
async def test__start_task(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__task_name: MagicMock,
|
||||||
mock__task_name.reset_mock()
|
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
|
||||||
mock__task_wrapper.reset_mock()
|
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||||
mock_create_task.reset_mock()
|
|
||||||
|
|
||||||
mock_create_task.return_value = mock_task = MagicMock()
|
mock_create_task.return_value = mock_task = MagicMock()
|
||||||
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
|
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
|
||||||
mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
|
mock_coroutine, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
|
||||||
self.task_pool._counter = count = 123
|
self.task_pool._num_started = count = 123
|
||||||
self.task_pool._enough_room._value = room = 123
|
self.task_pool._enough_room._value = room = 123
|
||||||
|
group_name, ignore_lock = 'testgroup', True
|
||||||
def check_nothing_changed() -> None:
|
output = await self.task_pool._start_task(mock_coroutine, group_name=group_name, ignore_lock=ignore_lock,
|
||||||
self.assertEqual(count, self.task_pool._counter)
|
|
||||||
self.assertNotIn(count, self.task_pool._running)
|
|
||||||
self.assertEqual(room, self.task_pool._enough_room._value)
|
|
||||||
mock__task_name.assert_not_called()
|
|
||||||
mock__task_wrapper.assert_not_called()
|
|
||||||
mock_create_task.assert_not_called()
|
|
||||||
reset_mocks()
|
|
||||||
|
|
||||||
with self.assertRaises(exceptions.NotCoroutine):
|
|
||||||
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
|
||||||
check_nothing_changed()
|
|
||||||
|
|
||||||
self.task_pool._locked = True
|
|
||||||
ignore_closed = False
|
|
||||||
mock_awaitable = mock_coroutine()
|
|
||||||
with self.assertRaises(exceptions.PoolIsLocked):
|
|
||||||
await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
|
||||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
||||||
await mock_awaitable
|
|
||||||
check_nothing_changed()
|
|
||||||
|
|
||||||
ignore_closed = True
|
|
||||||
mock_awaitable = mock_coroutine()
|
|
||||||
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
|
||||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
|
||||||
await mock_awaitable
|
|
||||||
self.assertEqual(count, output)
|
self.assertEqual(count, output)
|
||||||
self.assertEqual(count + 1, self.task_pool._counter)
|
mock__check_start.assert_called_once_with(awaitable=mock_coroutine, ignore_lock=ignore_lock)
|
||||||
self.assertEqual(mock_task, self.task_pool._running[count])
|
|
||||||
self.assertEqual(room - 1, self.task_pool._enough_room._value)
|
self.assertEqual(room - 1, self.task_pool._enough_room._value)
|
||||||
|
self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name])
|
||||||
|
mock_reg_cls.assert_called_once_with()
|
||||||
|
mock_group_reg.__aenter__.assert_awaited_once_with()
|
||||||
|
mock_group_reg.add.assert_called_once_with(count)
|
||||||
mock__task_name.assert_called_once_with(count)
|
mock__task_name.assert_called_once_with(count)
|
||||||
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
|
mock__task_wrapper.assert_called_once_with(mock_coroutine, count, mock_end_cb, mock_cancel_cb)
|
||||||
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
|
mock_create_task.assert_called_once_with(coro=mock_wrapped, name=FOO)
|
||||||
reset_mocks()
|
self.assertEqual(mock_task, self.task_pool._tasks_running[count])
|
||||||
self.task_pool._counter = count
|
mock_group_reg.__aexit__.assert_awaited_once()
|
||||||
self.task_pool._enough_room._value = room
|
|
||||||
del self.task_pool._running[count]
|
|
||||||
|
|
||||||
mock_awaitable = mock_coroutine()
|
|
||||||
mock_create_task.side_effect = test_exception = TestException()
|
|
||||||
with self.assertRaises(TestException) as e:
|
|
||||||
await self.task_pool._start_task(mock_awaitable, ignore_closed,
|
|
||||||
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
|
|
||||||
self.assertEqual(test_exception, e)
|
|
||||||
await mock_awaitable
|
|
||||||
self.assertEqual(count + 1, self.task_pool._counter)
|
|
||||||
self.assertNotIn(count, self.task_pool._running)
|
|
||||||
self.assertEqual(room, self.task_pool._enough_room._value)
|
|
||||||
mock__task_name.assert_called_once_with(count)
|
|
||||||
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
|
|
||||||
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
|
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
|
||||||
def test__get_running_task(self, mock__task_name: MagicMock):
|
def test__get_running_task(self, mock__task_name: MagicMock):
|
||||||
task_id, mock_task = 555, MagicMock()
|
task_id, mock_task = 555, MagicMock()
|
||||||
self.task_pool._running[task_id] = mock_task
|
self.task_pool._tasks_running[task_id] = mock_task
|
||||||
output = self.task_pool._get_running_task(task_id)
|
output = self.task_pool._get_running_task(task_id)
|
||||||
self.assertEqual(mock_task, output)
|
self.assertEqual(mock_task, output)
|
||||||
|
|
||||||
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
|
self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_running.pop(task_id)
|
||||||
with self.assertRaises(exceptions.AlreadyCancelled):
|
with self.assertRaises(exceptions.AlreadyCancelled):
|
||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_called_once_with(task_id)
|
mock__task_name.assert_called_once_with(task_id)
|
||||||
mock__task_name.reset_mock()
|
mock__task_name.reset_mock()
|
||||||
|
|
||||||
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
|
self.task_pool._tasks_ended[task_id] = self.task_pool._tasks_cancelled.pop(task_id)
|
||||||
with self.assertRaises(exceptions.TaskEnded):
|
with self.assertRaises(exceptions.TaskEnded):
|
||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_called_once_with(task_id)
|
mock__task_name.assert_called_once_with(task_id)
|
||||||
mock__task_name.reset_mock()
|
mock__task_name.reset_mock()
|
||||||
|
|
||||||
del self.task_pool._ended[task_id]
|
del self.task_pool._tasks_ended[task_id]
|
||||||
with self.assertRaises(exceptions.InvalidTaskID):
|
with self.assertRaises(exceptions.InvalidTaskID):
|
||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_not_called()
|
mock__task_name.assert_not_called()
|
||||||
@ -344,263 +330,416 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||||
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
||||||
|
|
||||||
def test_cancel_all(self):
|
def test__cancel_and_remove_all_from_group(self):
|
||||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
task_id = 555
|
||||||
self.task_pool._running = {1: mock_task1, 2: mock_task2}
|
mock_cancel = MagicMock()
|
||||||
assert not self.task_pool._interrupt_flag.is_set()
|
self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
|
||||||
self.assertIsNone(self.task_pool.cancel_all(FOO))
|
|
||||||
self.assertTrue(self.task_pool._interrupt_flag.is_set())
|
class MockRegister(set, MagicMock):
|
||||||
mock_task1.cancel.assert_called_once_with(msg=FOO)
|
pass
|
||||||
mock_task2.cancel.assert_called_once_with(msg=FOO)
|
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
|
||||||
|
mock_cancel.assert_called_once_with(msg=FOO)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
|
async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
||||||
|
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
|
||||||
|
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
|
||||||
|
self.task_pool._task_groups[FOO] = mock_group_reg
|
||||||
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
|
await self.task_pool.cancel_group(BAR)
|
||||||
|
mock__cancel_and_remove_all_from_group.assert_not_called()
|
||||||
|
mock_grp_aenter.assert_not_called()
|
||||||
|
mock_grp_aexit.assert_not_called()
|
||||||
|
self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR))
|
||||||
|
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
|
||||||
|
mock_grp_aenter.assert_awaited_once_with()
|
||||||
|
mock_grp_aexit.assert_awaited_once()
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
|
async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
||||||
|
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
|
||||||
|
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
|
||||||
|
self.task_pool._task_groups[BAR] = mock_group_reg
|
||||||
|
self.assertIsNone(await self.task_pool.cancel_all(FOO))
|
||||||
|
mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO)
|
||||||
|
mock_grp_aenter.assert_awaited_once_with()
|
||||||
|
mock_grp_aexit.assert_awaited_once()
|
||||||
|
|
||||||
async def test_flush(self):
|
async def test_flush(self):
|
||||||
test_exception = TestException()
|
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
||||||
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
|
self.task_pool._tasks_ended = {123: mock_ended_func()}
|
||||||
self.task_pool._ended = {123: mock_ended_func()}
|
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
|
||||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
|
||||||
self.task_pool._interrupt_flag.set()
|
mock_ended_func.assert_awaited_once_with()
|
||||||
output = await self.task_pool.flush(return_exceptions=True)
|
mock_cancelled_func.assert_awaited_once_with()
|
||||||
self.assertListEqual([FOO, test_exception], output)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
|
||||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
|
||||||
|
|
||||||
self.task_pool._ended = {123: mock_ended_func()}
|
async def test_gather_and_close(self):
|
||||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
|
||||||
output = await self.task_pool.flush(return_exceptions=True)
|
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
||||||
self.assertListEqual([FOO, test_exception], output)
|
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
|
||||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
|
||||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
|
||||||
|
self.task_pool._tasks_running = running = {789: mock_running_func()}
|
||||||
|
|
||||||
async def test_gather(self):
|
|
||||||
test_exception = TestException()
|
|
||||||
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
|
|
||||||
mock_running_func = AsyncMock(return_value=BAR)
|
|
||||||
mock_queue_join = AsyncMock()
|
|
||||||
self.task_pool._before_gathering = before_gather = [mock_queue_join()]
|
|
||||||
self.task_pool._ended = ended = {123: mock_ended_func()}
|
|
||||||
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
|
|
||||||
self.task_pool._running = running = {789: mock_running_func()}
|
|
||||||
self.task_pool._interrupt_flag.set()
|
|
||||||
|
|
||||||
assert not self.task_pool._locked
|
|
||||||
with self.assertRaises(exceptions.PoolStillUnlocked):
|
with self.assertRaises(exceptions.PoolStillUnlocked):
|
||||||
await self.task_pool.gather()
|
await self.task_pool.gather_and_close()
|
||||||
self.assertDictEqual(self.task_pool._ended, ended)
|
self.assertDictEqual(ended, self.task_pool._tasks_ended)
|
||||||
self.assertDictEqual(self.task_pool._cancelled, cancelled)
|
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
|
||||||
self.assertDictEqual(self.task_pool._running, running)
|
self.assertDictEqual(running, self.task_pool._tasks_running)
|
||||||
self.assertListEqual(self.task_pool._before_gathering, before_gather)
|
self.assertListEqual(before_gather, self.task_pool._before_gathering)
|
||||||
self.assertTrue(self.task_pool._interrupt_flag.is_set())
|
self.assertFalse(self.task_pool._closed)
|
||||||
|
|
||||||
self.task_pool._locked = True
|
self.task_pool._locked = True
|
||||||
|
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
||||||
def check_assertions(output) -> None:
|
mock_before_gather.assert_awaited_once_with()
|
||||||
self.assertListEqual([FOO, test_exception, BAR], output)
|
mock_ended_func.assert_awaited_once_with()
|
||||||
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
|
mock_cancelled_func.assert_awaited_once_with()
|
||||||
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
|
mock_running_func.assert_awaited_once_with()
|
||||||
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
self.assertFalse(self.task_pool._interrupt_flag.is_set())
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
|
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
|
||||||
check_assertions(await self.task_pool.gather(return_exceptions=True))
|
self.assertTrue(self.task_pool._closed)
|
||||||
|
|
||||||
self.task_pool._before_gathering = [mock_queue_join()]
|
|
||||||
self.task_pool._ended = {123: mock_ended_func()}
|
|
||||||
self.task_pool._cancelled = {456: mock_cancelled_func()}
|
|
||||||
self.task_pool._running = {789: mock_running_func()}
|
|
||||||
check_assertions(await self.task_pool.gather(return_exceptions=True))
|
|
||||||
|
|
||||||
|
|
||||||
class TaskPoolTestCase(CommonTestCase):
|
class TaskPoolTestCase(CommonTestCase):
|
||||||
TEST_CLASS = pool.TaskPool
|
TEST_CLASS = pool.TaskPool
|
||||||
task_pool: pool.TaskPool
|
task_pool: pool.TaskPool
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_start_task')
|
def setUp(self) -> None:
|
||||||
async def test__apply_one(self, mock__start_task: AsyncMock):
|
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
|
||||||
mock__start_task.return_value = expected_output = 12345
|
self.base_class_init = self.base_class_init_patcher.start()
|
||||||
mock_awaitable = MagicMock()
|
super().setUp()
|
||||||
mock_func = MagicMock(return_value=mock_awaitable)
|
|
||||||
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
|
def tearDown(self) -> None:
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
self.base_class_init_patcher.stop()
|
||||||
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
|
super().tearDown()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
||||||
|
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
|
||||||
|
|
||||||
|
def test__cancel_group_meta_tasks(self):
|
||||||
|
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||||
|
self.task_pool._group_meta_tasks_running[BAR] = {mock_task1, mock_task2}
|
||||||
|
self.assertIsNone(self.task_pool._cancel_group_meta_tasks(FOO))
|
||||||
|
self.assertDictEqual({BAR: {mock_task1, mock_task2}}, self.task_pool._group_meta_tasks_running)
|
||||||
|
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||||
|
mock_task1.cancel.assert_not_called()
|
||||||
|
mock_task2.cancel.assert_not_called()
|
||||||
|
|
||||||
|
self.assertIsNone(self.task_pool._cancel_group_meta_tasks(BAR))
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
||||||
|
self.assertSetEqual({mock_task1, mock_task2}, self.task_pool._meta_tasks_cancelled)
|
||||||
|
mock_task1.cancel.assert_called_once_with()
|
||||||
|
mock_task2.cancel.assert_called_once_with()
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
|
@patch.object(pool.TaskPool, '_cancel_group_meta_tasks')
|
||||||
|
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock,
|
||||||
|
mock_base__cancel_and_remove_all_from_group: MagicMock):
|
||||||
|
group_name, group_reg, msg = 'xyz', MagicMock(), FOO
|
||||||
|
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg))
|
||||||
|
mock__cancel_group_meta_tasks.assert_called_once_with(group_name)
|
||||||
|
mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, 'cancel_group')
|
||||||
|
async def test_cancel_group(self, mock_base_cancel_group: AsyncMock):
|
||||||
|
group_name, msg = 'abc', 'xyz'
|
||||||
|
await self.task_pool.cancel_group(group_name, msg=msg)
|
||||||
|
mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, 'cancel_all')
|
||||||
|
async def test_cancel_all(self, mock_base_cancel_all: AsyncMock):
|
||||||
|
msg = 'xyz'
|
||||||
|
await self.task_pool.cancel_all(msg=msg)
|
||||||
|
mock_base_cancel_all.assert_awaited_once_with(msg=msg)
|
||||||
|
|
||||||
|
def test__pop_ended_meta_tasks(self):
|
||||||
|
mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
|
||||||
|
self.task_pool._group_meta_tasks_running[FOO] = {mock_task, mock_done_task1}
|
||||||
|
mock_done_task2, mock_done_task3 = MagicMock(done=lambda: True), MagicMock(done=lambda: True)
|
||||||
|
self.task_pool._group_meta_tasks_running[BAR] = {mock_done_task2, mock_done_task3}
|
||||||
|
expected_output = {mock_done_task1, mock_done_task2, mock_done_task3}
|
||||||
|
output = self.task_pool._pop_ended_meta_tasks()
|
||||||
|
self.assertSetEqual(expected_output, output)
|
||||||
|
self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
|
||||||
|
|
||||||
|
@patch.object(pool.TaskPool, '_pop_ended_meta_tasks')
|
||||||
|
@patch.object(pool.BaseTaskPool, 'flush')
|
||||||
|
async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock):
|
||||||
|
mock_ended_meta_task = AsyncMock()
|
||||||
|
mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
|
||||||
|
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
||||||
|
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
||||||
|
self.assertIsNone(await self.task_pool.flush(return_exceptions=False))
|
||||||
|
mock_base_flush.assert_awaited_once_with(return_exceptions=False)
|
||||||
|
mock__pop_ended_meta_tasks.assert_called_once_with()
|
||||||
|
mock_ended_meta_task.assert_awaited_once_with()
|
||||||
|
mock_cancelled_meta_task.assert_awaited_once_with()
|
||||||
|
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, 'gather_and_close')
|
||||||
|
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock):
|
||||||
|
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
|
||||||
|
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
|
||||||
|
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
||||||
|
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
||||||
|
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
||||||
|
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
|
||||||
|
mock_meta_task1.assert_awaited_once_with()
|
||||||
|
mock_meta_task2.assert_awaited_once_with()
|
||||||
|
mock_cancelled_meta_task.assert_awaited_once_with()
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
||||||
|
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||||
|
|
||||||
|
@patch.object(pool, 'datetime')
|
||||||
|
def test__generate_group_name(self, mock_datetime: MagicMock):
|
||||||
|
prefix, func = 'x y z', AsyncMock(__name__=BAR)
|
||||||
|
dt = datetime(1776, 7, 4, 0, 0, 1)
|
||||||
|
mock_datetime.now = MagicMock(return_value=dt)
|
||||||
|
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}'
|
||||||
|
output = pool.TaskPool._generate_group_name(prefix, func)
|
||||||
self.assertEqual(expected_output, output)
|
self.assertEqual(expected_output, output)
|
||||||
mock_func.assert_called_once_with(*args, **kwargs)
|
|
||||||
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
|
@patch.object(pool.TaskPool, '_start_task')
|
||||||
|
async def test__apply_num(self, mock__start_task: AsyncMock):
|
||||||
|
group_name = FOO + BAR
|
||||||
|
mock_awaitable = object()
|
||||||
|
mock_func = MagicMock(return_value=mock_awaitable)
|
||||||
|
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
|
||||||
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
|
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
|
||||||
|
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
|
||||||
|
mock__start_task.assert_has_awaits(3 * [
|
||||||
|
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
])
|
||||||
|
|
||||||
mock_func.reset_mock()
|
mock_func.reset_mock()
|
||||||
mock__start_task.reset_mock()
|
mock__start_task.reset_mock()
|
||||||
|
|
||||||
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
|
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
|
||||||
self.assertEqual(expected_output, output)
|
mock_func.assert_has_calls(num * [call(*args)])
|
||||||
mock_func.assert_called_once_with(*args)
|
mock__start_task.assert_has_awaits(num * [
|
||||||
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
|
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
])
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_apply_one')
|
@patch.object(pool, 'create_task')
|
||||||
async def test_apply(self, mock__apply_one: AsyncMock):
|
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
|
||||||
mock__apply_one.return_value = mock_id = 67890
|
@patch.object(pool, 'TaskGroupRegister')
|
||||||
mock_func, num = MagicMock(), 3
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
|
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||||
|
async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
|
||||||
|
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
|
||||||
|
mock__generate_group_name.return_value = generated_name = 'name 123'
|
||||||
|
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||||
|
mock__apply_num.return_value = mock_apply_coroutine = object()
|
||||||
|
mock_task_future = AsyncMock()
|
||||||
|
mock_create_task.return_value = mock_task_future()
|
||||||
|
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
|
||||||
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
|
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
expected_output = num * [mock_id]
|
self.task_pool._task_groups = {}
|
||||||
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
|
|
||||||
self.assertEqual(expected_output, output)
|
|
||||||
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
|
|
||||||
|
|
||||||
async def test__queue_producer(self):
|
def check_assertions(_group_name, _output):
|
||||||
|
self.assertEqual(_group_name, _output)
|
||||||
|
mock__check_start.assert_called_once_with(function=mock_func)
|
||||||
|
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
|
||||||
|
mock_group_reg.__aenter__.assert_awaited_once_with()
|
||||||
|
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
|
||||||
|
mock_create_task.assert_called_once_with(mock_apply_coroutine)
|
||||||
|
mock_group_reg.__aexit__.assert_awaited_once()
|
||||||
|
mock_task_future.assert_awaited_once_with()
|
||||||
|
|
||||||
|
output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
|
||||||
|
check_assertions(group_name, output)
|
||||||
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
|
mock__check_start.reset_mock()
|
||||||
|
self.task_pool._task_groups.clear()
|
||||||
|
mock_group_reg.__aenter__.reset_mock()
|
||||||
|
mock__apply_num.reset_mock()
|
||||||
|
mock_create_task.reset_mock()
|
||||||
|
mock_group_reg.__aexit__.reset_mock()
|
||||||
|
mock_task_future = AsyncMock()
|
||||||
|
mock_create_task.return_value = mock_task_future()
|
||||||
|
|
||||||
|
output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
|
||||||
|
check_assertions(generated_name, output)
|
||||||
|
mock__generate_group_name.assert_called_once_with('apply', mock_func)
|
||||||
|
|
||||||
|
@patch.object(pool, 'Queue')
|
||||||
|
async def test__queue_producer(self, mock_queue_cls: MagicMock):
|
||||||
mock_put = AsyncMock()
|
mock_put = AsyncMock()
|
||||||
mock_q = MagicMock(put=mock_put)
|
mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
|
||||||
args = (FOO, BAR, 123)
|
item1, item2, item3 = FOO, 420, 69
|
||||||
assert not self.task_pool._interrupt_flag.is_set()
|
arg_iter = iter([item1, item2, item3])
|
||||||
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
|
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
|
||||||
mock_put.assert_has_awaits([call(arg) for arg in args])
|
mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
next(arg_iter)
|
||||||
|
|
||||||
mock_put.reset_mock()
|
mock_put.reset_mock()
|
||||||
self.task_pool._interrupt_flag.set()
|
|
||||||
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
|
|
||||||
mock_put.assert_not_awaited()
|
|
||||||
|
|
||||||
@patch.object(pool, 'partial')
|
mock_put.side_effect = [CancelledError, None]
|
||||||
@patch.object(pool, 'star_function')
|
arg_iter = iter([item1, item2, item3])
|
||||||
@patch.object(pool.TaskPool, '_start_task')
|
mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
|
||||||
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
|
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
|
||||||
mock_partial: MagicMock):
|
mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
|
||||||
mock_partial.return_value = queue_callback = 'not really'
|
mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
|
||||||
mock_star_function.return_value = awaitable = 'totally an awaitable'
|
mock_queue.item_processed.assert_has_calls([call(), call()])
|
||||||
q, arg = Queue(), 420.69
|
self.assertListEqual([item2, item3], list(arg_iter))
|
||||||
q.put_nowait(arg)
|
|
||||||
mock_func, stars = MagicMock(), 3
|
|
||||||
mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock()
|
|
||||||
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
|
|
||||||
self.assertTrue(q.empty())
|
|
||||||
mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True,
|
|
||||||
end_callback=queue_callback, cancel_callback=cancel_cb)
|
|
||||||
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
|
|
||||||
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
|
|
||||||
q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars,
|
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
|
||||||
mock__start_task.reset_mock()
|
|
||||||
mock_star_function.reset_mock()
|
|
||||||
mock_partial.reset_mock()
|
|
||||||
|
|
||||||
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
|
|
||||||
self.assertTrue(q.empty())
|
|
||||||
mock__start_task.assert_not_awaited()
|
|
||||||
mock_star_function.assert_not_called()
|
|
||||||
mock_partial.assert_not_called()
|
|
||||||
|
|
||||||
@patch.object(pool, 'execute_optional')
|
@patch.object(pool, 'execute_optional')
|
||||||
@patch.object(pool.TaskPool, '_queue_consumer')
|
async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
|
||||||
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
|
semaphore, mock_end_cb = Semaphore(1), MagicMock()
|
||||||
task_id, mock_q = 420, MagicMock()
|
wrapped = pool.TaskPool._get_map_end_callback(semaphore, mock_end_cb)
|
||||||
mock_func, stars = MagicMock(), 3
|
task_id = 1234
|
||||||
mock_wait = AsyncMock()
|
await wrapped(task_id)
|
||||||
mock_flag = MagicMock(wait=mock_wait)
|
self.assertEqual(2, semaphore._value)
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
mock_execute_optional.assert_awaited_once_with(mock_end_cb, args=(task_id,))
|
||||||
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars,
|
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb))
|
@patch.object(pool, 'star_function')
|
||||||
mock_wait.assert_awaited_once_with()
|
@patch.object(pool.TaskPool, '_start_task')
|
||||||
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars,
|
@patch.object(pool, 'Semaphore')
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
@patch.object(pool.TaskPool, '_get_map_end_callback')
|
||||||
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
|
async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock,
|
||||||
|
mock__start_task: AsyncMock, mock_star_function: MagicMock):
|
||||||
|
mock__get_map_end_callback.return_value = map_cb = MagicMock()
|
||||||
|
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
|
||||||
|
mock_star_function.return_value = awaitable = 'totally an awaitable'
|
||||||
|
arg1, arg2 = 123456789, 'function argument'
|
||||||
|
mock_q_maxsize = 3
|
||||||
|
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]),
|
||||||
|
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
|
||||||
|
group_name, mock_func, stars = 'whatever', MagicMock(), 3
|
||||||
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
|
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
|
||||||
|
self.assertTrue(semaphore.locked())
|
||||||
|
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
||||||
|
mock__start_task.assert_has_awaits(2 * [
|
||||||
|
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
|
||||||
|
])
|
||||||
|
mock_star_function.assert_has_calls([
|
||||||
|
call(mock_func, arg1, arg_stars=stars),
|
||||||
|
call(mock_func, arg2, arg_stars=stars)
|
||||||
|
])
|
||||||
|
|
||||||
@patch.object(pool, 'iter')
|
|
||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool, 'join_queue', new_callable=MagicMock)
|
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
|
||||||
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
|
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
|
||||||
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
|
@patch.object(pool, 'join_queue', new_callable=MagicMock)
|
||||||
mock_create_task: MagicMock, mock_iter: MagicMock):
|
@patch.object(pool, 'Queue')
|
||||||
args, num_tasks = (FOO, BAR, 1, 2, 3), 2
|
@patch.object(pool, 'TaskGroupRegister')
|
||||||
mock_join_queue.return_value = mock_join = 'awaitable'
|
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||||
mock_iter.return_value = args_iter = iter(args)
|
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
|
||||||
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
|
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
|
||||||
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
|
mock_create_task: MagicMock):
|
||||||
self.assertIsInstance(output_q, Queue)
|
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||||
self.assertEqual(num_tasks, output_q.qsize())
|
mock_queue_cls.return_value = mock_q = MagicMock()
|
||||||
for arg in args[:num_tasks]:
|
mock_join_queue.return_value = fake_join = object()
|
||||||
self.assertEqual(arg, output_q.get_nowait())
|
mock__queue_producer.return_value = fake_producer = object()
|
||||||
self.assertTrue(output_q.empty())
|
mock__queue_consumer.return_value = fake_consumer = object()
|
||||||
for arg in args[num_tasks:]:
|
fake_task1, fake_task2 = object(), object()
|
||||||
self.assertEqual(arg, next(args_iter))
|
mock_create_task.side_effect = [fake_task1, fake_task2]
|
||||||
with self.assertRaises(StopIteration):
|
|
||||||
next(args_iter)
|
|
||||||
self.assertListEqual([mock_join], self.task_pool._before_gathering)
|
|
||||||
mock_join_queue.assert_called_once_with(output_q)
|
|
||||||
mock__queue_producer.assert_called_once_with(output_q, args_iter)
|
|
||||||
mock_create_task.assert_called_once_with(mock_producer_coro)
|
|
||||||
|
|
||||||
self.task_pool._before_gathering.clear()
|
group_name, group_size = 'onetwothree', 0
|
||||||
mock_join_queue.reset_mock()
|
func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
|
||||||
mock__queue_producer.reset_mock()
|
|
||||||
mock_create_task.reset_mock()
|
|
||||||
|
|
||||||
num_tasks = 6
|
|
||||||
mock_iter.return_value = args_iter = iter(args)
|
|
||||||
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
|
|
||||||
self.assertIsInstance(output_q, Queue)
|
|
||||||
self.assertEqual(len(args), output_q.qsize())
|
|
||||||
for arg in args:
|
|
||||||
self.assertEqual(arg, output_q.get_nowait())
|
|
||||||
self.assertTrue(output_q.empty())
|
|
||||||
with self.assertRaises(StopIteration):
|
|
||||||
next(args_iter)
|
|
||||||
self.assertListEqual([mock_join], self.task_pool._before_gathering)
|
|
||||||
mock_join_queue.assert_called_once_with(output_q)
|
|
||||||
mock__queue_producer.assert_not_called()
|
|
||||||
mock_create_task.assert_not_called()
|
|
||||||
|
|
||||||
@patch.object(pool, 'Event')
|
|
||||||
@patch.object(pool.TaskPool, '_queue_consumer')
|
|
||||||
@patch.object(pool.TaskPool, '_set_up_args_queue')
|
|
||||||
async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock,
|
|
||||||
mock_event_cls: MagicMock):
|
|
||||||
qsize = 4
|
|
||||||
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
|
|
||||||
mock_flag_set = MagicMock()
|
|
||||||
mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
|
|
||||||
|
|
||||||
mock_func, stars = MagicMock(), 3
|
|
||||||
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
|
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
|
|
||||||
self.task_pool._locked = False
|
with self.assertRaises(ValueError):
|
||||||
with self.assertRaises(exceptions.PoolIsLocked):
|
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
|
||||||
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
|
mock__check_start.assert_called_once_with(function=func)
|
||||||
mock__set_up_args_queue.assert_not_called()
|
|
||||||
mock__queue_consumer.assert_not_awaited()
|
|
||||||
mock_flag_set.assert_not_called()
|
|
||||||
|
|
||||||
self.task_pool._locked = True
|
mock__check_start.reset_mock()
|
||||||
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
|
|
||||||
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
|
group_size = 1234
|
||||||
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars,
|
self.task_pool._task_groups = {group_name: MagicMock()}
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)])
|
|
||||||
mock_flag_set.assert_called_once_with()
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
|
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
|
||||||
|
mock__check_start.assert_called_once_with(function=func)
|
||||||
|
|
||||||
|
mock__check_start.reset_mock()
|
||||||
|
|
||||||
|
self.task_pool._task_groups.clear()
|
||||||
|
self.task_pool._before_gathering = []
|
||||||
|
|
||||||
|
self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
|
||||||
|
mock__check_start.assert_called_once_with(function=func)
|
||||||
|
mock_reg_cls.assert_called_once_with()
|
||||||
|
self.task_pool._task_groups[group_name] = mock_group_reg
|
||||||
|
mock_group_reg.__aenter__.assert_awaited_once_with()
|
||||||
|
mock_queue_cls.assert_called_once_with(maxsize=group_size)
|
||||||
|
mock_join_queue.assert_called_once_with(mock_q)
|
||||||
|
self.assertListEqual([fake_join], self.task_pool._before_gathering)
|
||||||
|
mock__queue_producer.assert_called_once()
|
||||||
|
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
|
||||||
|
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
|
||||||
|
self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
|
||||||
|
mock_group_reg.__aexit__.assert_awaited_once()
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_map')
|
@patch.object(pool.TaskPool, '_map')
|
||||||
async def test_map(self, mock__map: AsyncMock):
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
|
async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
||||||
|
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||||
mock_func = MagicMock()
|
mock_func = MagicMock()
|
||||||
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
|
arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
|
output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
|
||||||
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
|
self.assertEqual(group_name, output)
|
||||||
|
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
|
mock__map.reset_mock()
|
||||||
|
output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
|
||||||
|
self.assertEqual(generated_name, output)
|
||||||
|
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
|
||||||
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
mock__generate_group_name.assert_called_once_with('map', mock_func)
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_map')
|
@patch.object(pool.TaskPool, '_map')
|
||||||
async def test_starmap(self, mock__map: AsyncMock):
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
|
async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
||||||
|
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||||
mock_func = MagicMock()
|
mock_func = MagicMock()
|
||||||
args_iter, num_tasks = ([FOO], [BAR]), 2
|
args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
|
output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
|
||||||
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
|
self.assertEqual(group_name, output)
|
||||||
|
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
|
mock__map.reset_mock()
|
||||||
|
output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
|
||||||
|
self.assertEqual(generated_name, output)
|
||||||
|
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
|
||||||
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
mock__generate_group_name.assert_called_once_with('starmap', mock_func)
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_map')
|
@patch.object(pool.TaskPool, '_map')
|
||||||
async def test_doublestarmap(self, mock__map: AsyncMock):
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
|
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
||||||
|
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||||
mock_func = MagicMock()
|
mock_func = MagicMock()
|
||||||
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
|
kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
|
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
|
||||||
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
|
self.assertEqual(group_name, output)
|
||||||
|
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
|
mock__map.reset_mock()
|
||||||
|
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
|
||||||
|
self.assertEqual(generated_name, output)
|
||||||
|
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
|
||||||
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
|
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
|
||||||
|
|
||||||
|
|
||||||
class SimpleTaskPoolTestCase(CommonTestCase):
|
class SimpleTaskPoolTestCase(CommonTestCase):
|
||||||
@ -667,7 +806,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
def test_stop(self, mock_cancel: MagicMock):
|
def test_stop(self, mock_cancel: MagicMock):
|
||||||
num = 2
|
num = 2
|
||||||
id1, id2, id3 = 5, 6, 7
|
id1, id2, id3 = 5, 6, 7
|
||||||
self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR}
|
self.task_pool._tasks_running = {id1: FOO, id2: BAR, id3: FOO + BAR}
|
||||||
output = self.task_pool.stop(num)
|
output = self.task_pool.stop(num)
|
||||||
expected_output = [id3, id2]
|
expected_output = [id3, id2]
|
||||||
self.assertEqual(expected_output, output)
|
self.assertEqual(expected_output, output)
|
||||||
@ -689,3 +828,10 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertEqual(expected_output, output)
|
self.assertEqual(expected_output, output)
|
||||||
mock_num_running.assert_called_once_with()
|
mock_num_running.assert_called_once_with()
|
||||||
mock_stop.assert_called_once_with(num)
|
mock_stop.assert_called_once_with(num)
|
||||||
|
|
||||||
|
|
||||||
|
def set_up_mock_group_register(mock_reg_cls: MagicMock) -> MagicMock:
|
||||||
|
mock_grp_aenter, mock_grp_aexit, mock_grp_add = AsyncMock(), AsyncMock(), MagicMock()
|
||||||
|
mock_reg_cls.return_value = mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit,
|
||||||
|
add=mock_grp_add)
|
||||||
|
return mock_group_reg
|
||||||
|
166
tests/test_server.py
Normal file
166
tests/test_server.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Unittests for the `asyncio_taskpool.server` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import IsolatedAsyncioTestCase, skipIf
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool import server
|
||||||
|
from asyncio_taskpool.client import ControlClient, UnixControlClient
|
||||||
|
|
||||||
|
|
||||||
|
FOO, BAR = 'foo', 'bar'
|
||||||
|
|
||||||
|
|
||||||
|
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||||
|
log_lvl: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.log_lvl = server.log.level
|
||||||
|
server.log.setLevel(999)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
server.log.setLevel(cls.log_lvl)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
|
||||||
|
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||||
|
self.mock_pool = MagicMock()
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.server = server.ControlServer(pool=self.mock_pool, **self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.abstract_patcher.stop()
|
||||||
|
|
||||||
|
def test_client_class_name(self):
|
||||||
|
self.assertEqual(ControlClient.__name__, server.ControlServer.client_class_name)
|
||||||
|
|
||||||
|
async def test_abstract(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
args = [AsyncMock()]
|
||||||
|
await self.server._get_server_instance(*args)
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
self.server._final_callback()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(self.mock_pool, self.server._pool)
|
||||||
|
self.assertEqual(self.kwargs, self.server._server_kwargs)
|
||||||
|
self.assertIsNone(self.server._server)
|
||||||
|
|
||||||
|
def test_pool(self):
|
||||||
|
self.assertEqual(self.mock_pool, self.server.pool)
|
||||||
|
|
||||||
|
def test_is_serving(self):
|
||||||
|
self.server._server = MagicMock(is_serving=MagicMock(return_value=FOO + BAR))
|
||||||
|
self.assertEqual(FOO + BAR, self.server.is_serving())
|
||||||
|
|
||||||
|
@patch.object(server, 'ControlSession')
|
||||||
|
async def test__client_connected_cb(self, mock_client_session_cls: MagicMock):
|
||||||
|
mock_client_handshake, mock_listen = AsyncMock(), AsyncMock()
|
||||||
|
mock_client_session_cls.return_value = MagicMock(client_handshake=mock_client_handshake, listen=mock_listen)
|
||||||
|
mock_reader, mock_writer = MagicMock(), MagicMock()
|
||||||
|
self.assertIsNone(await self.server._client_connected_cb(mock_reader, mock_writer))
|
||||||
|
mock_client_session_cls.assert_called_once_with(self.server, mock_reader, mock_writer)
|
||||||
|
mock_client_handshake.assert_awaited_once_with()
|
||||||
|
mock_listen.assert_awaited_once_with()
|
||||||
|
|
||||||
|
@patch.object(server.ControlServer, '_final_callback')
|
||||||
|
async def test__serve_forever(self, mock__final_callback: MagicMock):
|
||||||
|
mock_aenter, mock_serve_forever = AsyncMock(), AsyncMock(side_effect=asyncio.CancelledError)
|
||||||
|
self.server._server = MagicMock(__aenter__=mock_aenter, serve_forever=mock_serve_forever)
|
||||||
|
with self.assertLogs(server.log, logging.DEBUG):
|
||||||
|
self.assertIsNone(await self.server._serve_forever())
|
||||||
|
mock_aenter.assert_awaited_once_with()
|
||||||
|
mock_serve_forever.assert_awaited_once_with()
|
||||||
|
mock__final_callback.assert_called_once_with()
|
||||||
|
|
||||||
|
mock_aenter.reset_mock()
|
||||||
|
mock_serve_forever.reset_mock(side_effect=True)
|
||||||
|
mock__final_callback.reset_mock()
|
||||||
|
|
||||||
|
self.assertIsNone(await self.server._serve_forever())
|
||||||
|
mock_aenter.assert_awaited_once_with()
|
||||||
|
mock_serve_forever.assert_awaited_once_with()
|
||||||
|
mock__final_callback.assert_called_once_with()
|
||||||
|
|
||||||
|
@patch.object(server, 'create_task')
|
||||||
|
@patch.object(server.ControlServer, '_serve_forever', new_callable=MagicMock())
|
||||||
|
@patch.object(server.ControlServer, '_get_server_instance')
|
||||||
|
async def test_serve_forever(self, mock__get_server_instance: AsyncMock, mock__serve_forever: MagicMock,
|
||||||
|
mock_create_task: MagicMock):
|
||||||
|
mock__serve_forever.return_value = mock_awaitable = 'some_coroutine'
|
||||||
|
mock_create_task.return_value = expected_output = 12345
|
||||||
|
output = await self.server.serve_forever()
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock__get_server_instance.assert_awaited_once_with(self.server._client_connected_cb, **self.kwargs)
|
||||||
|
mock__serve_forever.assert_called_once_with()
|
||||||
|
mock_create_task.assert_called_once_with(mock_awaitable)
|
||||||
|
|
||||||
|
|
||||||
|
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
|
||||||
|
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
|
||||||
|
log_lvl: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.log_lvl = server.log.level
|
||||||
|
server.log.setLevel(999)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
server.log.setLevel(cls.log_lvl)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
|
||||||
|
self.mock_base_init = self.base_init_patcher.start()
|
||||||
|
self.mock_pool = MagicMock()
|
||||||
|
self.path = '/tmp/asyncio_taskpool'
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.base_init_patcher.stop()
|
||||||
|
|
||||||
|
def test__client_class(self):
|
||||||
|
self.assertEqual(UnixControlClient, self.server._client_class)
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(Path(self.path), self.server._socket_path)
|
||||||
|
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
|
||||||
|
|
||||||
|
@patch.object(server, 'start_unix_server')
|
||||||
|
async def test__get_server_instance(self, mock_start_unix_server: AsyncMock):
|
||||||
|
mock_start_unix_server.return_value = expected_output = 'totally_a_server'
|
||||||
|
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
|
||||||
|
args = [mock_callback]
|
||||||
|
output = await self.server._get_server_instance(*args, **mock_kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_start_unix_server.assert_called_once_with(mock_callback, Path(self.path), **mock_kwargs)
|
||||||
|
|
||||||
|
def test__final_callback(self):
|
||||||
|
self.server._socket_path = MagicMock()
|
||||||
|
self.assertIsNone(self.server._final_callback())
|
||||||
|
self.server._socket_path.unlink.assert_called_once_with()
|
324
tests/test_session.py
Normal file
324
tests/test_session.py
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Unittests for the `asyncio_taskpool.session` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
from argparse import ArgumentError, Namespace
|
||||||
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||||
|
|
||||||
|
from asyncio_taskpool import session
|
||||||
|
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER
|
||||||
|
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
|
||||||
|
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
|
||||||
|
|
||||||
|
|
||||||
|
FOO, BAR = 'foo', 'bar'
|
||||||
|
|
||||||
|
|
||||||
|
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||||
|
log_lvl: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.log_lvl = session.log.level
|
||||||
|
session.log.setLevel(999)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
session.log.setLevel(cls.log_lvl)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
|
||||||
|
self.mock_client_class_name = FOO + BAR
|
||||||
|
self.mock_server = MagicMock(pool=self.mock_pool,
|
||||||
|
client_class_name=self.mock_client_class_name)
|
||||||
|
self.mock_reader = MagicMock()
|
||||||
|
self.mock_writer = MagicMock()
|
||||||
|
self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(self.mock_server, self.session._control_server)
|
||||||
|
self.assertEqual(self.mock_pool, self.session._pool)
|
||||||
|
self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
|
||||||
|
self.assertEqual(self.mock_reader, self.session._reader)
|
||||||
|
self.assertEqual(self.mock_writer, self.session._writer)
|
||||||
|
self.assertIsNone(self.session._parser)
|
||||||
|
self.assertIsNone(self.session._subparsers)
|
||||||
|
|
||||||
|
def test__add_command(self):
|
||||||
|
expected_output = 123456
|
||||||
|
mock_add_parser = MagicMock(return_value=expected_output)
|
||||||
|
self.session._subparsers = MagicMock(add_parser=mock_add_parser)
|
||||||
|
self.session._parser = MagicMock()
|
||||||
|
name, prog, short_help, long_help = 'abc', None, 'short123', None
|
||||||
|
kwargs = {'x': 1, 'y': 2}
|
||||||
|
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
|
||||||
|
parent=self.session._parser, **kwargs)
|
||||||
|
|
||||||
|
mock_add_parser.reset_mock()
|
||||||
|
|
||||||
|
prog, long_help = 'ffffff', 'so long, wow'
|
||||||
|
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
|
||||||
|
parent=self.session._parser, **kwargs)
|
||||||
|
|
||||||
|
mock_add_parser.reset_mock()
|
||||||
|
|
||||||
|
short_help = None
|
||||||
|
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
|
||||||
|
parent=self.session._parser, **kwargs)
|
||||||
|
|
||||||
|
@patch.object(session, 'get_first_doc_line')
|
||||||
|
@patch.object(session.ControlSession, '_add_command')
|
||||||
|
def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
|
||||||
|
self.assertIsNone(self.session._add_base_commands())
|
||||||
|
mock__add_command.assert_called()
|
||||||
|
mock_get_first_doc_line.assert_called()
|
||||||
|
|
||||||
|
mock__add_command.reset_mock()
|
||||||
|
mock_get_first_doc_line.reset_mock()
|
||||||
|
|
||||||
|
self.assertIsNone(self.session._add_simple_commands())
|
||||||
|
mock__add_command.assert_called()
|
||||||
|
mock_get_first_doc_line.assert_called()
|
||||||
|
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
self.session._add_advanced_commands()
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_add_simple_commands')
|
||||||
|
@patch.object(session.ControlSession, '_add_advanced_commands')
|
||||||
|
@patch.object(session.ControlSession, '_add_base_commands')
|
||||||
|
@patch.object(session, 'CommandParser')
|
||||||
|
def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
|
||||||
|
mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
|
||||||
|
mock_command_parser_cls.return_value = mock_parser = MagicMock()
|
||||||
|
self.session._pool = TaskPool()
|
||||||
|
width = 1234
|
||||||
|
expected_parser_kwargs = {
|
||||||
|
'prog': '',
|
||||||
|
SESSION_WRITER: self.mock_writer,
|
||||||
|
CLIENT_INFO.TERMINAL_WIDTH: width,
|
||||||
|
}
|
||||||
|
self.assertIsNone(self.session._init_parser(width))
|
||||||
|
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
|
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||||
|
mock__add_base_commands.assert_called_once_with()
|
||||||
|
mock__add_advanced_commands.assert_called_once_with()
|
||||||
|
mock__add_simple_commands.assert_not_called()
|
||||||
|
|
||||||
|
mock_command_parser_cls.reset_mock()
|
||||||
|
mock_parser.add_subparsers.reset_mock()
|
||||||
|
mock__add_base_commands.reset_mock()
|
||||||
|
mock__add_advanced_commands.reset_mock()
|
||||||
|
mock__add_simple_commands.reset_mock()
|
||||||
|
|
||||||
|
async def fake_coroutine(): pass
|
||||||
|
|
||||||
|
self.session._pool = SimpleTaskPool(fake_coroutine)
|
||||||
|
self.assertIsNone(self.session._init_parser(width))
|
||||||
|
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
|
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||||
|
mock__add_base_commands.assert_called_once_with()
|
||||||
|
mock__add_advanced_commands.assert_not_called()
|
||||||
|
mock__add_simple_commands.assert_called_once_with()
|
||||||
|
|
||||||
|
mock_command_parser_cls.reset_mock()
|
||||||
|
mock_parser.add_subparsers.reset_mock()
|
||||||
|
mock__add_base_commands.reset_mock()
|
||||||
|
mock__add_advanced_commands.reset_mock()
|
||||||
|
mock__add_simple_commands.reset_mock()
|
||||||
|
|
||||||
|
class FakeTaskPool(BaseTaskPool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.session._pool = FakeTaskPool()
|
||||||
|
with self.assertRaises(UnknownTaskPoolClass):
|
||||||
|
self.session._init_parser(width)
|
||||||
|
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
|
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||||
|
mock__add_base_commands.assert_called_once_with()
|
||||||
|
mock__add_advanced_commands.assert_not_called()
|
||||||
|
mock__add_simple_commands.assert_not_called()
|
||||||
|
|
||||||
|
mock_command_parser_cls.reset_mock()
|
||||||
|
mock_parser.add_subparsers.reset_mock()
|
||||||
|
mock__add_base_commands.reset_mock()
|
||||||
|
mock__add_advanced_commands.reset_mock()
|
||||||
|
mock__add_simple_commands.reset_mock()
|
||||||
|
|
||||||
|
self.session._pool = MagicMock()
|
||||||
|
with self.assertRaises(NotATaskPool):
|
||||||
|
self.session._init_parser(width)
|
||||||
|
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
|
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
|
||||||
|
mock__add_base_commands.assert_called_once_with()
|
||||||
|
mock__add_advanced_commands.assert_not_called()
|
||||||
|
mock__add_simple_commands.assert_not_called()
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_init_parser')
|
||||||
|
async def test_client_handshake(self, mock__init_parser: MagicMock):
|
||||||
|
width = 5678
|
||||||
|
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
||||||
|
mock_read = AsyncMock(return_value=msg.encode())
|
||||||
|
self.mock_reader.read = mock_read
|
||||||
|
self.mock_writer.drain = AsyncMock()
|
||||||
|
self.assertIsNone(await self.session.client_handshake())
|
||||||
|
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
|
mock__init_parser.assert_called_once_with(width)
|
||||||
|
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
|
||||||
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
|
@patch.object(session, 'return_or_exception')
|
||||||
|
async def test__write_function_output(self, mock_return_or_exception: MagicMock):
|
||||||
|
self.mock_writer.write = MagicMock()
|
||||||
|
mock_return_or_exception.return_value = None
|
||||||
|
func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
|
||||||
|
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
|
||||||
|
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
|
||||||
|
self.mock_writer.write.assert_called_once_with(b"ok")
|
||||||
|
|
||||||
|
mock_return_or_exception.reset_mock()
|
||||||
|
self.mock_writer.write.reset_mock()
|
||||||
|
|
||||||
|
mock_return_or_exception.return_value = output = MagicMock()
|
||||||
|
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
|
||||||
|
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
|
||||||
|
self.mock_writer.write.assert_called_once_with(str(output).encode())
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_name(self, mock__write_function_output: AsyncMock):
|
||||||
|
self.assertIsNone(await self.session._cmd_name())
|
||||||
|
mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
|
||||||
|
num = 12345
|
||||||
|
kwargs = {session.NUM: num, FOO: BAR}
|
||||||
|
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
|
||||||
|
mock__write_function_output.assert_awaited_once_with(
|
||||||
|
self.mock_pool.__class__.pool_size.fset, self.session._pool, num
|
||||||
|
)
|
||||||
|
|
||||||
|
mock__write_function_output.reset_mock()
|
||||||
|
|
||||||
|
kwargs.pop(session.NUM)
|
||||||
|
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
|
||||||
|
mock__write_function_output.assert_awaited_once_with(
|
||||||
|
self.mock_pool.__class__.pool_size.fget, self.session._pool
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
|
||||||
|
self.assertIsNone(await self.session._cmd_num_running())
|
||||||
|
mock__write_function_output.assert_awaited_once_with(
|
||||||
|
self.mock_pool.__class__.num_running.fget, self.session._pool
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_start(self, mock__write_function_output: AsyncMock):
|
||||||
|
num = 12345
|
||||||
|
kwargs = {session.NUM: num, FOO: BAR}
|
||||||
|
self.assertIsNone(await self.session._cmd_start(**kwargs))
|
||||||
|
mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
|
||||||
|
num = 12345
|
||||||
|
kwargs = {session.NUM: num, FOO: BAR}
|
||||||
|
self.assertIsNone(await self.session._cmd_stop(**kwargs))
|
||||||
|
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
|
||||||
|
self.assertIsNone(await self.session._cmd_stop_all())
|
||||||
|
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_write_function_output')
|
||||||
|
async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
|
||||||
|
self.assertIsNone(await self.session._cmd_func_name())
|
||||||
|
mock__write_function_output.assert_awaited_once_with(
|
||||||
|
self.mock_pool.__class__.func_name.fget, self.session._pool
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test__execute_command(self):
|
||||||
|
mock_method = AsyncMock()
|
||||||
|
cmd = 'this-is-a-test'
|
||||||
|
setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
|
||||||
|
kwargs = {FOO: BAR, 'hello': 'python'}
|
||||||
|
self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
|
||||||
|
mock_method.assert_awaited_once_with(**kwargs)
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_execute_command')
|
||||||
|
async def test__parse_command(self, mock__execute_command: AsyncMock):
|
||||||
|
msg = 'asdf asd as a'
|
||||||
|
kwargs = {FOO: BAR, 'hello': 'python'}
|
||||||
|
mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
|
||||||
|
self.session._parser = MagicMock(parse_args=mock_parse_args)
|
||||||
|
self.mock_writer.write = MagicMock()
|
||||||
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
|
self.mock_writer.write.assert_not_called()
|
||||||
|
mock__execute_command.assert_awaited_once_with(**kwargs)
|
||||||
|
|
||||||
|
mock__execute_command.reset_mock()
|
||||||
|
mock_parse_args.reset_mock()
|
||||||
|
|
||||||
|
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
|
||||||
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
|
self.mock_writer.write.assert_called_once_with(str(exc).encode())
|
||||||
|
mock__execute_command.assert_not_awaited()
|
||||||
|
|
||||||
|
self.mock_writer.write.reset_mock()
|
||||||
|
mock_parse_args.reset_mock()
|
||||||
|
|
||||||
|
mock_parse_args.side_effect = HelpRequested()
|
||||||
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
|
self.mock_writer.write.assert_not_called()
|
||||||
|
mock__execute_command.assert_not_awaited()
|
||||||
|
|
||||||
|
@patch.object(session.ControlSession, '_parse_command')
|
||||||
|
async def test_listen(self, mock__parse_command: AsyncMock):
|
||||||
|
def make_reader_return_empty():
|
||||||
|
self.mock_reader.read.return_value = b''
|
||||||
|
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
||||||
|
msg = "fascinating"
|
||||||
|
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
|
||||||
|
self.assertIsNone(await self.session.listen())
|
||||||
|
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
|
||||||
|
mock__parse_command.assert_awaited_once_with(msg)
|
||||||
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
|
self.mock_reader.read.reset_mock()
|
||||||
|
mock__parse_command.reset_mock()
|
||||||
|
self.mock_writer.drain.reset_mock()
|
||||||
|
|
||||||
|
self.mock_server.is_serving = MagicMock(return_value=False)
|
||||||
|
self.assertIsNone(await self.session.listen())
|
||||||
|
self.mock_reader.read.assert_not_awaited()
|
||||||
|
mock__parse_command.assert_not_awaited()
|
||||||
|
self.mock_writer.drain.assert_not_awaited()
|
134
tests/test_session_parser.py
Normal file
134
tests/test_session_parser.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Unittests for the `asyncio_taskpool.session_parser` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter
|
||||||
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool import session_parser
|
||||||
|
from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO
|
||||||
|
from asyncio_taskpool.exceptions import HelpRequested
|
||||||
|
|
||||||
|
|
||||||
|
FOO = 'foo'
|
||||||
|
|
||||||
|
|
||||||
|
class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory')
|
||||||
|
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
||||||
|
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
||||||
|
self.session_writer, self.terminal_width = MagicMock(), 420
|
||||||
|
self.kwargs = {
|
||||||
|
SESSION_WRITER: self.session_writer,
|
||||||
|
CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width,
|
||||||
|
session_parser.FORMATTER_CLASS: FOO
|
||||||
|
}
|
||||||
|
self.parser = session_parser.CommandParser(**self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.help_formatter_factory_patcher.stop()
|
||||||
|
|
||||||
|
def test_help_formatter_factory(self):
|
||||||
|
self.help_formatter_factory_patcher.stop()
|
||||||
|
|
||||||
|
class MockBaseClass(HelpFormatter):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
terminal_width = 123456789
|
||||||
|
cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass)
|
||||||
|
self.assertTrue(issubclass(cls, MockBaseClass))
|
||||||
|
instance = cls('prog')
|
||||||
|
self.assertEqual(terminal_width, getattr(instance, '_width'))
|
||||||
|
|
||||||
|
cls = session_parser.CommandParser.help_formatter_factory(terminal_width)
|
||||||
|
self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter))
|
||||||
|
instance = cls('prog')
|
||||||
|
self.assertEqual(terminal_width, getattr(instance, '_width'))
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertIsInstance(self.parser, ArgumentParser)
|
||||||
|
self.assertEqual(self.session_writer, self.parser._session_writer)
|
||||||
|
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
||||||
|
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
||||||
|
self.assertFalse(getattr(self.parser, 'exit_on_error'))
|
||||||
|
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
||||||
|
|
||||||
|
def test_session_writer(self):
|
||||||
|
self.assertEqual(self.session_writer, self.parser.session_writer)
|
||||||
|
|
||||||
|
def test_terminal_width(self):
|
||||||
|
self.assertEqual(self.terminal_width, self.parser.terminal_width)
|
||||||
|
|
||||||
|
def test__print_message(self):
|
||||||
|
self.session_writer.write = MagicMock()
|
||||||
|
self.assertIsNone(self.parser._print_message(''))
|
||||||
|
self.session_writer.write.assert_not_called()
|
||||||
|
msg = 'foo bar baz'
|
||||||
|
self.assertIsNone(self.parser._print_message(msg))
|
||||||
|
self.session_writer.write.assert_called_once_with(msg.encode())
|
||||||
|
|
||||||
|
@patch.object(session_parser.CommandParser, '_print_message')
|
||||||
|
def test_exit(self, mock__print_message: MagicMock):
|
||||||
|
self.assertIsNone(self.parser.exit(123, ''))
|
||||||
|
mock__print_message.assert_not_called()
|
||||||
|
msg = 'foo bar baz'
|
||||||
|
self.assertIsNone(self.parser.exit(123, msg))
|
||||||
|
mock__print_message.assert_called_once_with(msg)
|
||||||
|
|
||||||
|
@patch.object(session_parser.ArgumentParser, 'print_help')
|
||||||
|
def test_print_help(self, mock_print_help: MagicMock):
|
||||||
|
arg = MagicMock()
|
||||||
|
with self.assertRaises(HelpRequested):
|
||||||
|
self.parser.print_help(arg)
|
||||||
|
mock_print_help.assert_called_once_with(arg)
|
||||||
|
|
||||||
|
def test_add_optional_num_argument(self):
|
||||||
|
metavar = 'FOOBAR'
|
||||||
|
action = self.parser.add_optional_num_argument(metavar=metavar)
|
||||||
|
self.assertIsInstance(action, Action)
|
||||||
|
self.assertEqual('?', action.nargs)
|
||||||
|
self.assertEqual(1, action.default)
|
||||||
|
self.assertEqual(int, action.type)
|
||||||
|
self.assertEqual(metavar, action.metavar)
|
||||||
|
num = 111
|
||||||
|
kwargs = vars(self.parser.parse_args([f'{num}']))
|
||||||
|
self.assertDictEqual({session_parser.NUM: num}, kwargs)
|
||||||
|
|
||||||
|
name = f'--{FOO}'
|
||||||
|
nargs = '+'
|
||||||
|
default = 1
|
||||||
|
_type = float
|
||||||
|
required = True
|
||||||
|
dest = 'foo_bar'
|
||||||
|
action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type,
|
||||||
|
required=required, metavar=metavar, dest=dest)
|
||||||
|
self.assertIsInstance(action, Action)
|
||||||
|
self.assertEqual(nargs, action.nargs)
|
||||||
|
self.assertEqual(default, action.default)
|
||||||
|
self.assertEqual(_type, action.type)
|
||||||
|
self.assertEqual(required, action.required)
|
||||||
|
self.assertEqual(metavar, action.metavar)
|
||||||
|
self.assertEqual(dest, action.dest)
|
||||||
|
kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5']))
|
||||||
|
self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs)
|
207
usage/USAGE.md
207
usage/USAGE.md
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
## Minimal example for `SimpleTaskPool`
|
## Minimal example for `SimpleTaskPool`
|
||||||
|
|
||||||
|
With a `SimpleTaskPool` the function to execute as well as the arguments with which to execute it must be defined during its initialization (and they cannot be changed later). The only control you have after initialization is how many of such tasks are being run.
|
||||||
|
|
||||||
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
|
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
|
||||||
|
|
||||||
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
|
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
|
||||||
@ -28,18 +30,18 @@ async def work(n: int) -> None:
|
|||||||
"""
|
"""
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
print("did", i)
|
print("> did", i)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
|
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
|
||||||
await pool.start(3) # launches work tasks 0, 1, and 2
|
await pool.start(3) # launches work tasks 0, 1, and 2
|
||||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||||
await pool.start() # launches work task 3
|
await pool.start() # launches work task 3
|
||||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||||
pool.stop(2) # cancels tasks 3 and 2
|
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
|
||||||
pool.lock() # required for the last line
|
pool.lock() # required for the last line
|
||||||
await pool.gather() # awaits all tasks, then flushes the pool
|
await pool.gather_and_close() # awaits all tasks, then flushes the pool
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -52,29 +54,29 @@ SimpleTaskPool-0 initialized
|
|||||||
Started SimpleTaskPool-0_Task-0
|
Started SimpleTaskPool-0_Task-0
|
||||||
Started SimpleTaskPool-0_Task-1
|
Started SimpleTaskPool-0_Task-1
|
||||||
Started SimpleTaskPool-0_Task-2
|
Started SimpleTaskPool-0_Task-2
|
||||||
did 0
|
> did 0
|
||||||
did 0
|
> did 0
|
||||||
did 0
|
> did 0
|
||||||
Started SimpleTaskPool-0_Task-3
|
Started SimpleTaskPool-0_Task-3
|
||||||
did 1
|
> did 1
|
||||||
did 1
|
> did 1
|
||||||
did 1
|
> did 1
|
||||||
did 0
|
> did 0
|
||||||
|
> did 2
|
||||||
|
> did 2
|
||||||
SimpleTaskPool-0 is locked!
|
SimpleTaskPool-0 is locked!
|
||||||
Cancelling SimpleTaskPool-0_Task-3 ...
|
|
||||||
Cancelled SimpleTaskPool-0_Task-3
|
|
||||||
Ended SimpleTaskPool-0_Task-3
|
|
||||||
Cancelling SimpleTaskPool-0_Task-2 ...
|
Cancelling SimpleTaskPool-0_Task-2 ...
|
||||||
Cancelled SimpleTaskPool-0_Task-2
|
Cancelled SimpleTaskPool-0_Task-2
|
||||||
Ended SimpleTaskPool-0_Task-2
|
Ended SimpleTaskPool-0_Task-2
|
||||||
did 2
|
Cancelling SimpleTaskPool-0_Task-3 ...
|
||||||
did 2
|
Cancelled SimpleTaskPool-0_Task-3
|
||||||
did 3
|
Ended SimpleTaskPool-0_Task-3
|
||||||
did 3
|
> did 3
|
||||||
|
> did 3
|
||||||
Ended SimpleTaskPool-0_Task-0
|
Ended SimpleTaskPool-0_Task-0
|
||||||
Ended SimpleTaskPool-0_Task-1
|
Ended SimpleTaskPool-0_Task-1
|
||||||
did 4
|
> did 4
|
||||||
did 4
|
> did 4
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced example for `TaskPool`
|
## Advanced example for `TaskPool`
|
||||||
@ -101,43 +103,41 @@ async def work(start: int, stop: int, step: int = 1) -> None:
|
|||||||
"""Pseudo-worker function counting through a range with a second of sleep in between each iteration."""
|
"""Pseudo-worker function counting through a range with a second of sleep in between each iteration."""
|
||||||
for i in range(start, stop, step):
|
for i in range(start, stop, step):
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
print("work with", i)
|
print("> work with", i)
|
||||||
|
|
||||||
|
|
||||||
async def other_work(a: int, b: int) -> None:
|
async def other_work(a: int, b: int) -> None:
|
||||||
"""Different pseudo-worker counting through a range with half a second of sleep in between each iteration."""
|
"""Different pseudo-worker counting through a range with half a second of sleep in between each iteration."""
|
||||||
for i in range(a, b):
|
for i in range(a, b):
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
print("other_work with", i)
|
print("> other_work with", i)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
# Initialize a new task pool instance and limit its size to 3 tasks.
|
# Initialize a new task pool instance and limit its size to 3 tasks.
|
||||||
pool = TaskPool(3)
|
pool = TaskPool(3)
|
||||||
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments).
|
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments).
|
||||||
print("Called `apply`")
|
print("> Called `apply`")
|
||||||
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
|
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
|
||||||
# Let the tasks work for a bit.
|
# Let the tasks work for a bit.
|
||||||
await asyncio.sleep(1.5)
|
await asyncio.sleep(1.5)
|
||||||
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
|
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
|
||||||
# positional arguments by using `starmap`, but have **no more than two of those** run concurrently.
|
# positional arguments by using `starmap`, but we want no more than two of those to run concurrently.
|
||||||
# Since we set our pool size to 3, and already have two tasks working within the pool,
|
# Since we set our pool size to 3, and already have two tasks working within the pool,
|
||||||
# only the first one of these will start immediately (and receive ID 2).
|
# only the first one of these will start immediately (and receive ID 2).
|
||||||
# The second one will start (with ID 3), only once there is room in the pool,
|
# The second one will start (with ID 3), only once there is room in the pool,
|
||||||
# which -- in this example -- will be the case after ID 2 ends;
|
# which -- in this example -- will be the case after ID 2 ends.
|
||||||
# until then the `starmap` method call **will block**!
|
|
||||||
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
|
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
|
||||||
# **only** once there is room in the pool **and** no more than one of these last four tasks is running.
|
# only once there is room in the pool and no more than one other task of these new ones is running.
|
||||||
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
||||||
print("Calling `starmap`...")
|
await pool.starmap(other_work, args_list, group_size=2)
|
||||||
await pool.starmap(other_work, args_list, num_tasks=2)
|
print("> Called `starmap`")
|
||||||
print("`starmap` returned")
|
|
||||||
# Now we lock the pool, so that we can safely await all our tasks.
|
# Now we lock the pool, so that we can safely await all our tasks.
|
||||||
pool.lock()
|
pool.lock()
|
||||||
# Finally, we block, until all tasks have ended.
|
# Finally, we block, until all tasks have ended.
|
||||||
print("Called `gather`")
|
print("> Calling `gather_and_close`...")
|
||||||
await pool.gather()
|
await pool.gather_and_close()
|
||||||
print("Done.")
|
print("> Done.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -152,82 +152,81 @@ Additional comments for the output are provided with `<---` next to the output l
|
|||||||
TaskPool-0 initialized
|
TaskPool-0 initialized
|
||||||
Started TaskPool-0_Task-0
|
Started TaskPool-0_Task-0
|
||||||
Started TaskPool-0_Task-1
|
Started TaskPool-0_Task-1
|
||||||
Called `apply`
|
> Called `apply`
|
||||||
work with 100
|
> work with 100
|
||||||
work with 100
|
> work with 100
|
||||||
Calling `starmap`... <--- notice that this blocks as expected
|
> Called `starmap` <--- notice that this immediately returns, even before Task-2 is started
|
||||||
Started TaskPool-0_Task-2
|
> Calling `gather_and_close`... <--- this blocks `main()` until all tasks have ended
|
||||||
work with 110
|
|
||||||
work with 110
|
|
||||||
other_work with 0
|
|
||||||
other_work with 1
|
|
||||||
work with 120
|
|
||||||
work with 120
|
|
||||||
other_work with 2
|
|
||||||
other_work with 3
|
|
||||||
work with 130
|
|
||||||
work with 130
|
|
||||||
other_work with 4
|
|
||||||
other_work with 5
|
|
||||||
work with 140
|
|
||||||
work with 140
|
|
||||||
other_work with 6
|
|
||||||
other_work with 7
|
|
||||||
work with 150
|
|
||||||
work with 150
|
|
||||||
other_work with 8
|
|
||||||
Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
|
|
||||||
TaskPool-0 is locked!
|
TaskPool-0 is locked!
|
||||||
|
Started TaskPool-0_Task-2 <--- at this point the pool is full
|
||||||
|
> work with 110
|
||||||
|
> work with 110
|
||||||
|
> other_work with 0
|
||||||
|
> other_work with 1
|
||||||
|
> work with 120
|
||||||
|
> work with 120
|
||||||
|
> other_work with 2
|
||||||
|
> other_work with 3
|
||||||
|
> work with 130
|
||||||
|
> work with 130
|
||||||
|
> other_work with 4
|
||||||
|
> other_work with 5
|
||||||
|
> work with 140
|
||||||
|
> work with 140
|
||||||
|
> other_work with 6
|
||||||
|
> other_work with 7
|
||||||
|
> work with 150
|
||||||
|
> work with 150
|
||||||
|
> other_work with 8
|
||||||
|
Ended TaskPool-0_Task-2 <--- this frees up room for one more task from `starmap`
|
||||||
Started TaskPool-0_Task-3
|
Started TaskPool-0_Task-3
|
||||||
other_work with 9
|
> other_work with 9
|
||||||
`starmap` returned
|
> work with 160
|
||||||
Called `gather` <--- now this will block `main()` until all tasks have ended
|
> work with 160
|
||||||
work with 160
|
> other_work with 10
|
||||||
work with 160
|
> other_work with 11
|
||||||
other_work with 10
|
> work with 170
|
||||||
other_work with 11
|
> work with 170
|
||||||
work with 170
|
> other_work with 12
|
||||||
work with 170
|
> other_work with 13
|
||||||
other_work with 12
|
> work with 180
|
||||||
other_work with 13
|
> work with 180
|
||||||
work with 180
|
> other_work with 14
|
||||||
work with 180
|
> other_work with 15
|
||||||
other_work with 14
|
|
||||||
other_work with 15
|
|
||||||
Ended TaskPool-0_Task-0
|
Ended TaskPool-0_Task-0
|
||||||
Ended TaskPool-0_Task-1 <--- even though there is room in the pool now, Task-5 will not start
|
Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
|
||||||
Started TaskPool-0_Task-4
|
Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start
|
||||||
work with 190
|
> work with 190
|
||||||
work with 190
|
> work with 190
|
||||||
other_work with 16
|
> other_work with 16
|
||||||
other_work with 20
|
> other_work with 17
|
||||||
other_work with 17
|
> other_work with 20
|
||||||
other_work with 21
|
> other_work with 18
|
||||||
other_work with 18
|
> other_work with 21
|
||||||
other_work with 22
|
Ended TaskPool-0_Task-3 <--- now that only Task-4 of the group remains, Task-5 starts
|
||||||
other_work with 19
|
|
||||||
Ended TaskPool-0_Task-3 <--- now that only Task-4 is left, Task-5 will start
|
|
||||||
Started TaskPool-0_Task-5
|
Started TaskPool-0_Task-5
|
||||||
other_work with 23
|
> other_work with 19
|
||||||
other_work with 30
|
> other_work with 22
|
||||||
other_work with 24
|
> other_work with 23
|
||||||
other_work with 31
|
> other_work with 30
|
||||||
other_work with 25
|
> other_work with 24
|
||||||
other_work with 32
|
> other_work with 31
|
||||||
other_work with 26
|
> other_work with 25
|
||||||
other_work with 33
|
> other_work with 32
|
||||||
other_work with 27
|
> other_work with 26
|
||||||
other_work with 34
|
> other_work with 33
|
||||||
other_work with 28
|
> other_work with 27
|
||||||
other_work with 35
|
> other_work with 34
|
||||||
|
> other_work with 28
|
||||||
|
> other_work with 35
|
||||||
|
> other_work with 29
|
||||||
|
> other_work with 36
|
||||||
Ended TaskPool-0_Task-4
|
Ended TaskPool-0_Task-4
|
||||||
other_work with 29
|
> other_work with 37
|
||||||
other_work with 36
|
> other_work with 38
|
||||||
other_work with 37
|
> other_work with 39
|
||||||
other_work with 38
|
|
||||||
other_work with 39
|
|
||||||
Done.
|
|
||||||
Ended TaskPool-0_Task-5
|
Ended TaskPool-0_Task-5
|
||||||
|
> Done.
|
||||||
```
|
```
|
||||||
|
|
||||||
© 2022 Daniil Fajnberg
|
© 2022 Daniil Fajnberg
|
||||||
|
@ -23,7 +23,7 @@ Use the main CLI client to interface at the socket.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from asyncio_taskpool import SimpleTaskPool, UnixControlServer
|
from asyncio_taskpool import SimpleTaskPool, TCPControlServer
|
||||||
from asyncio_taskpool.constants import PACKAGE_NAME
|
from asyncio_taskpool.constants import PACKAGE_NAME
|
||||||
|
|
||||||
|
|
||||||
@ -34,11 +34,11 @@ logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
|
|||||||
async def work(item: int) -> None:
|
async def work(item: int) -> None:
|
||||||
"""The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
|
"""The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
print("worked on", item)
|
print("worked on", item, flush=True)
|
||||||
|
|
||||||
|
|
||||||
async def worker(q: asyncio.Queue) -> None:
|
async def worker(q: asyncio.Queue) -> None:
|
||||||
"""Simulates doing asynchronous work that takes a little bit of time to finish."""
|
"""Simulates doing asynchronous work that takes a bit of time to finish."""
|
||||||
# We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
|
# We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
|
||||||
while True:
|
while True:
|
||||||
# We want to block here, until we can get the next item from the queue.
|
# We want to block here, until we can get the next item from the queue.
|
||||||
@ -67,19 +67,19 @@ async def main() -> None:
|
|||||||
q.put_nowait(item)
|
q.put_nowait(item)
|
||||||
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
|
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
|
||||||
await pool.start(3) # launches three worker tasks
|
await pool.start(3) # launches three worker tasks
|
||||||
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
|
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
|
||||||
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
||||||
await q.join()
|
await q.join()
|
||||||
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
|
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
|
||||||
control_server_task.cancel()
|
control_server_task.cancel()
|
||||||
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
||||||
# we can now safely cancel their tasks.
|
# we can now safely cancel their tasks.
|
||||||
pool.stop_all()
|
|
||||||
pool.lock()
|
pool.lock()
|
||||||
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
|
pool.stop_all()
|
||||||
|
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
|
||||||
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
||||||
# we just silently collect their exceptions along with their return values.
|
# we just silently collect their exceptions along with their return values.
|
||||||
await pool.gather(return_exceptions=True)
|
await pool.gather_and_close(return_exceptions=True)
|
||||||
await control_server_task
|
await control_server_task
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user