full unit test coverage and docstrings for client module; minor refactoring

This commit is contained in:
2022-02-19 12:56:08 +01:00
parent bac7b32342
commit d0c0177681
3 changed files with 298 additions and 13 deletions

View File

@ -25,54 +25,116 @@ import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from typing import Optional
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT
from .types import ClientConnT, PathT
class ControlClient(ABC):
"""
Abstract base class for a simple implementation of a task pool control client.
@abstractmethod
async def open_connection(self, **kwargs) -> ClientConnT:
raise NotImplementedError
Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
This is a minimal working implementation.
"""
@staticmethod
def client_info() -> dict:
"""Returns a dictionary of client information relevant for the handshake with the server."""
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
@abstractmethod
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
as keyword-arguments.
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
`None` and `None`, if it failed to establish the defined connection.
"""
raise NotImplementedError
def __init__(self, **conn_kwargs) -> None:
"""Simply stores the connection keyword-arguments necessary for opening the connection."""
self._conn_kwargs = conn_kwargs
self._connected: bool = False
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
Performs the first interaction with the server providing it with the necessary client information.
Upon completion, the server's info is printed.
Args:
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
self._connected = True
writer.write(json.dumps(self.client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
def _get_command(self, writer: StreamWriter) -> Optional[str]:
"""
Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.
Args:
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
"""
try:
msg = input("> ").strip().lower()
except EOFError:
except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
msg = CLIENT_EXIT
except KeyboardInterrupt:
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print()
return
if msg == CLIENT_EXIT:
writer.close()
self._connected = False
return
return msg
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
Reacts to the user's command, potentially performing a back-and-forth interaction with the server.
If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.
Args:
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
cmd = self._get_command(writer)
if cmd is None:
return
try:
writer.write(msg.encode())
# Send the command to the server.
writer.write(cmd.encode())
await writer.drain()
except ConnectionError as e:
self._connected = False
print(e, file=sys.stderr)
return
# Await the server's response, then print it.
print((await reader.read(SESSION_MSG_BYTES)).decode())
async def start(self):
reader, writer = await self.open_connection(**self._conn_kwargs)
async def start(self) -> None:
"""
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.
If the connection can not be established, an error message is printed to `stderr` and the method returns.
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
disconnected-message.
"""
reader, writer = await self._open_connection(**self._conn_kwargs)
if reader is None:
print("Failed to connect.", file=sys.stderr)
return
@ -83,11 +145,24 @@ class ControlClient(ABC):
class UnixControlClient(ControlClient):
def __init__(self, **conn_kwargs) -> None:
self._socket_path = Path(conn_kwargs.pop('path'))
"""Task pool control client that expects a unix socket to be exposed by the control server."""
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
"""
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
"""
self._socket_path = Path(socket_path)
super().__init__(**conn_kwargs)
async def open_connection(self, **kwargs) -> ClientConnT:
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Wrapper around the `asyncio.open_unix_connection` function.
Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
otherwise, the stream-reader and -writer tuple is returned.
"""
try:
return await open_unix_connection(self._socket_path, **kwargs)
except FileNotFoundError:

View File

@ -20,6 +20,7 @@ Custom type definitions used in various modules.
from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
@ -36,3 +37,5 @@ CancelCallbackT = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
PathT = Union[Path, str]