generated from daniil-berg/boilerplate-py
full unit test coverage and docstrings for client
module; minor refactoring
This commit is contained in:
parent
bac7b32342
commit
d0c0177681
@ -25,54 +25,116 @@ import sys
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
|
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
|
||||||
from .types import ClientConnT
|
from .types import ClientConnT, PathT
|
||||||
|
|
||||||
|
|
||||||
class ControlClient(ABC):
|
class ControlClient(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for a simple implementation of a task pool control client.
|
||||||
|
|
||||||
@abstractmethod
|
Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
|
||||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
|
||||||
raise NotImplementedError
|
This is a minimal working implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def client_info() -> dict:
|
def client_info() -> dict:
|
||||||
|
"""Returns a dictionary of client information relevant for the handshake with the server."""
|
||||||
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
|
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
"""
|
||||||
|
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
|
||||||
|
|
||||||
|
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
|
||||||
|
as keyword-arguments.
|
||||||
|
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
|
||||||
|
`None` and `None`, if it failed to establish the defined connection.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(self, **conn_kwargs) -> None:
|
def __init__(self, **conn_kwargs) -> None:
|
||||||
|
"""Simply stores the connection keyword-arguments necessary for opening the connection."""
|
||||||
self._conn_kwargs = conn_kwargs
|
self._conn_kwargs = conn_kwargs
|
||||||
self._connected: bool = False
|
self._connected: bool = False
|
||||||
|
|
||||||
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
"""
|
||||||
|
Performs the first interaction with the server providing it with the necessary client information.
|
||||||
|
|
||||||
|
Upon completion, the server's info is printed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
|
||||||
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
"""
|
||||||
self._connected = True
|
self._connected = True
|
||||||
writer.write(json.dumps(self.client_info()).encode())
|
writer.write(json.dumps(self.client_info()).encode())
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
|
|
||||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
def _get_command(self, writer: StreamWriter) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
|
||||||
|
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
msg = input("> ").strip().lower()
|
msg = input("> ").strip().lower()
|
||||||
except EOFError:
|
except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command.
|
||||||
msg = CLIENT_EXIT
|
msg = CLIENT_EXIT
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
||||||
print()
|
print()
|
||||||
return
|
return
|
||||||
if msg == CLIENT_EXIT:
|
if msg == CLIENT_EXIT:
|
||||||
writer.close()
|
writer.close()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
return
|
return
|
||||||
|
return msg
|
||||||
|
|
||||||
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
"""
|
||||||
|
Reacts to the user's command, potentially performing a back-and-forth interaction with the server.
|
||||||
|
|
||||||
|
If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`.
|
||||||
|
If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reader: The `asyncio.StreamReader` returned by the `_open_connection()` method
|
||||||
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
"""
|
||||||
|
cmd = self._get_command(writer)
|
||||||
|
if cmd is None:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
writer.write(msg.encode())
|
# Send the command to the server.
|
||||||
|
writer.write(cmd.encode())
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
print(e, file=sys.stderr)
|
print(e, file=sys.stderr)
|
||||||
return
|
return
|
||||||
|
# Await the server's response, then print it.
|
||||||
print((await reader.read(SESSION_MSG_BYTES)).decode())
|
print((await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
|
|
||||||
async def start(self):
|
async def start(self) -> None:
|
||||||
reader, writer = await self.open_connection(**self._conn_kwargs)
|
"""
|
||||||
|
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop.
|
||||||
|
|
||||||
|
If the connection can not be established, an error message is printed to `stderr` and the method returns.
|
||||||
|
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a
|
||||||
|
disconnected-message.
|
||||||
|
"""
|
||||||
|
reader, writer = await self._open_connection(**self._conn_kwargs)
|
||||||
if reader is None:
|
if reader is None:
|
||||||
print("Failed to connect.", file=sys.stderr)
|
print("Failed to connect.", file=sys.stderr)
|
||||||
return
|
return
|
||||||
@ -83,11 +145,24 @@ class ControlClient(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class UnixControlClient(ControlClient):
|
class UnixControlClient(ControlClient):
|
||||||
def __init__(self, **conn_kwargs) -> None:
|
"""Task pool control client that expects a unix socket to be exposed by the control server."""
|
||||||
self._socket_path = Path(conn_kwargs.pop('path'))
|
|
||||||
|
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
|
||||||
|
"""
|
||||||
|
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
|
||||||
|
|
||||||
|
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
|
||||||
|
"""
|
||||||
|
self._socket_path = Path(socket_path)
|
||||||
super().__init__(**conn_kwargs)
|
super().__init__(**conn_kwargs)
|
||||||
|
|
||||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
"""
|
||||||
|
Wrapper around the `asyncio.open_unix_connection` function.
|
||||||
|
|
||||||
|
Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path;
|
||||||
|
otherwise, the stream-reader and -writer tuple is returned.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return await open_unix_connection(self._socket_path, **kwargs)
|
return await open_unix_connection(self._socket_path, **kwargs)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
@ -20,6 +20,7 @@ Custom type definitions used in various modules.
|
|||||||
|
|
||||||
|
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
|
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
|
||||||
@ -36,3 +37,5 @@ CancelCallbackT = Callable
|
|||||||
|
|
||||||
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||||
|
|
||||||
|
PathT = Union[Path, str]
|
||||||
|
207
tests/test_client.py
Normal file
207
tests/test_client.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Unittests for the `asyncio_taskpool.client` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool import client
|
||||||
|
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
|
||||||
|
|
||||||
|
|
||||||
|
FOO, BAR = 'foo', 'bar'
|
||||||
|
|
||||||
|
|
||||||
|
class ControlClientTestCase(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set())
|
||||||
|
self.print_patcher = patch.object(client, 'print')
|
||||||
|
self.mock_abstract_methods = self.abstract_patcher.start()
|
||||||
|
self.mock_print = self.print_patcher.start()
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.client = client.ControlClient(**self.kwargs)
|
||||||
|
|
||||||
|
self.mock_read = AsyncMock(return_value=FOO.encode())
|
||||||
|
self.mock_write, self.mock_drain = MagicMock(), AsyncMock()
|
||||||
|
self.mock_reader = MagicMock(read=self.mock_read)
|
||||||
|
self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.abstract_patcher.stop()
|
||||||
|
self.print_patcher.stop()
|
||||||
|
|
||||||
|
def test_client_info(self):
|
||||||
|
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
|
||||||
|
client.ControlClient.client_info())
|
||||||
|
|
||||||
|
async def test_abstract(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
await self.client._open_connection(**self.kwargs)
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(self.kwargs, self.client._conn_kwargs)
|
||||||
|
self.assertFalse(self.client._connected)
|
||||||
|
|
||||||
|
@patch.object(client.ControlClient, 'client_info')
|
||||||
|
async def test__server_handshake(self, mock_client_info: MagicMock):
|
||||||
|
mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
|
||||||
|
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
mock_client_info.assert_called_once_with()
|
||||||
|
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
||||||
|
self.mock_drain.assert_awaited_once_with()
|
||||||
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
|
self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())
|
||||||
|
|
||||||
|
@patch.object(client, 'input')
|
||||||
|
def test__get_command(self, mock_input: MagicMock):
|
||||||
|
self.client._connected = True
|
||||||
|
|
||||||
|
mock_input.return_value = ' ' + FOO.upper() + ' '
|
||||||
|
mock_close = MagicMock()
|
||||||
|
mock_writer = MagicMock(close=mock_close)
|
||||||
|
output = self.client._get_command(mock_writer)
|
||||||
|
self.assertEqual(FOO, output)
|
||||||
|
mock_input.assert_called_once()
|
||||||
|
mock_close.assert_not_called()
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
mock_input.reset_mock()
|
||||||
|
mock_input.side_effect = KeyboardInterrupt
|
||||||
|
self.assertIsNone(self.client._get_command(mock_writer))
|
||||||
|
mock_input.assert_called_once()
|
||||||
|
mock_close.assert_not_called()
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
mock_input.reset_mock()
|
||||||
|
mock_input.side_effect = EOFError
|
||||||
|
self.assertIsNone(self.client._get_command(mock_writer))
|
||||||
|
mock_input.assert_called_once()
|
||||||
|
mock_close.assert_called_once()
|
||||||
|
self.assertFalse(self.client._connected)
|
||||||
|
|
||||||
|
@patch.object(client.ControlClient, '_get_command')
|
||||||
|
async def test__interact(self, mock__get_command: MagicMock):
|
||||||
|
self.client._connected = True
|
||||||
|
|
||||||
|
mock__get_command.return_value = None
|
||||||
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
|
self.mock_write.assert_not_called()
|
||||||
|
self.mock_drain.assert_not_awaited()
|
||||||
|
self.mock_read.assert_not_awaited()
|
||||||
|
self.mock_print.assert_not_called()
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
||||||
|
self.mock_drain.side_effect = err = ConnectionError()
|
||||||
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
|
self.mock_write.assert_called_once_with(cmd.encode())
|
||||||
|
self.mock_drain.assert_awaited_once_with()
|
||||||
|
self.mock_read.assert_not_awaited()
|
||||||
|
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||||
|
self.assertFalse(self.client._connected)
|
||||||
|
|
||||||
|
self.client._connected = True
|
||||||
|
self.mock_write.reset_mock()
|
||||||
|
self.mock_drain.reset_mock(side_effect=True)
|
||||||
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
|
self.mock_write.assert_called_once_with(cmd.encode())
|
||||||
|
self.mock_drain.assert_awaited_once_with()
|
||||||
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
|
self.mock_print.assert_called_once_with(FOO)
|
||||||
|
self.assertTrue(self.client._connected)
|
||||||
|
|
||||||
|
@patch.object(client.ControlClient, '_interact')
|
||||||
|
@patch.object(client.ControlClient, '_server_handshake')
|
||||||
|
@patch.object(client.ControlClient, '_open_connection')
|
||||||
|
async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock,
|
||||||
|
mock__interact: AsyncMock):
|
||||||
|
mock__open_connection.return_value = None, None
|
||||||
|
self.assertIsNone(await self.client.start())
|
||||||
|
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||||
|
mock__server_handshake.assert_not_awaited()
|
||||||
|
mock__interact.assert_not_awaited()
|
||||||
|
self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr)
|
||||||
|
|
||||||
|
mock__open_connection.reset_mock()
|
||||||
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
|
mock__open_connection.return_value = self.mock_reader, self.mock_writer
|
||||||
|
self.assertIsNone(await self.client.start())
|
||||||
|
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||||
|
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||||
|
mock__interact.assert_not_awaited()
|
||||||
|
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||||
|
|
||||||
|
mock__open_connection.reset_mock()
|
||||||
|
mock__server_handshake.reset_mock()
|
||||||
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
|
self.client._connected = True
|
||||||
|
def disconnect(*_args, **_kwargs) -> None: self.client._connected = False
|
||||||
|
mock__interact.side_effect = disconnect
|
||||||
|
self.assertIsNone(await self.client.start())
|
||||||
|
mock__open_connection.assert_awaited_once_with(**self.kwargs)
|
||||||
|
mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||||
|
mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer)
|
||||||
|
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||||
|
|
||||||
|
|
||||||
|
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
|
||||||
|
self.mock_base_init = self.base_init_patcher.start()
|
||||||
|
self.path = '/tmp/asyncio_taskpool'
|
||||||
|
self.kwargs = {FOO: 123, BAR: 456}
|
||||||
|
self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.base_init_patcher.stop()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(Path(self.path), self.client._socket_path)
|
||||||
|
self.mock_base_init.assert_called_once_with(**self.kwargs)
|
||||||
|
|
||||||
|
@patch.object(client, 'print')
|
||||||
|
@patch.object(client, 'open_unix_connection')
|
||||||
|
async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock):
|
||||||
|
mock_open_unix_connection.return_value = expected_output = 'something'
|
||||||
|
kwargs = {'a': 1, 'b': 2}
|
||||||
|
output = await self.client._open_connection(**kwargs)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||||
|
mock_print.assert_not_called()
|
||||||
|
|
||||||
|
mock_open_unix_connection.reset_mock()
|
||||||
|
|
||||||
|
mock_open_unix_connection.side_effect = FileNotFoundError
|
||||||
|
output1, output2 = await self.client._open_connection(**kwargs)
|
||||||
|
self.assertIsNone(output1)
|
||||||
|
self.assertIsNone(output2)
|
||||||
|
mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs)
|
||||||
|
mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)
|
Loading…
Reference in New Issue
Block a user