diff --git a/src/asyncio_taskpool/client.py b/src/asyncio_taskpool/client.py index 0362ce5..bb1a28f 100644 --- a/src/asyncio_taskpool/client.py +++ b/src/asyncio_taskpool/client.py @@ -25,54 +25,116 @@ import sys from abc import ABC, abstractmethod from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from pathlib import Path +from typing import Optional from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES -from .types import ClientConnT +from .types import ClientConnT, PathT class ControlClient(ABC): + """ + Abstract base class for a simple implementation of a task pool control client. - @abstractmethod - async def open_connection(self, **kwargs) -> ClientConnT: - raise NotImplementedError + Since the server's control interface is simply expecting commands to be sent, any process able to connect to the + TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well. + This is a minimal working implementation. + """ @staticmethod def client_info() -> dict: + """Returns a dictionary of client information relevant for the handshake with the server.""" return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns} + @abstractmethod + async def _open_connection(self, **kwargs) -> ClientConnT: + """ + Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair. + + This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked) + as keyword-arguments. + This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of + `None` and `None`, if it failed to establish the defined connection. + """ + raise NotImplementedError + def __init__(self, **conn_kwargs) -> None: + """Simply stores the connection keyword-arguments necessary for opening the connection.""" self._conn_kwargs = conn_kwargs self._connected: bool = False async def _server_handshake(self, reader: StreamReader, writer: StreamWriter) -> None: + """ + Performs the first interaction with the server providing it with the necessary client information. + + Upon completion, the server's info is printed. + + Args: + reader: The `asyncio.StreamReader` returned by the `_open_connection()` method + writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method + """ self._connected = True writer.write(json.dumps(self.client_info()).encode()) await writer.drain() print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) - async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: + def _get_command(self, writer: StreamWriter) -> Optional[str]: + """ + Prompts the user for input and either returns it (after cleaning it up) or `None` in special cases. + + Args: + writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method + + Returns: + `None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect; + otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase. + """ try: msg = input("> ").strip().lower() - except EOFError: + except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command. msg = CLIENT_EXIT - except KeyboardInterrupt: + except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt. print() return if msg == CLIENT_EXIT: writer.close() self._connected = False return + return msg + + async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: + """ + Reacts to the user's command, potentially performing a back-and-forth interaction with the server. + + If `_get_command` returns `None`, this may imply that the client disconnected, but may also just be `Ctrl+C`. + If an actual command is retrieved, it is written to the stream, a response is awaited and eventually printed. + + Args: + reader: The `asyncio.StreamReader` returned by the `_open_connection()` method + writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method + """ + cmd = self._get_command(writer) + if cmd is None: + return try: - writer.write(msg.encode()) + # Send the command to the server. + writer.write(cmd.encode()) await writer.drain() except ConnectionError as e: self._connected = False print(e, file=sys.stderr) return + # Await the server's response, then print it. print((await reader.read(SESSION_MSG_BYTES)).decode()) - async def start(self): - reader, writer = await self.open_connection(**self._conn_kwargs) + async def start(self) -> None: + """ + This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop. + + If the connection can not be established, an error message is printed to `stderr` and the method returns. + If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a + disconnected-message. + """ + reader, writer = await self._open_connection(**self._conn_kwargs) if reader is None: print("Failed to connect.", file=sys.stderr) return @@ -83,11 +145,24 @@ class ControlClient(ABC): class UnixControlClient(ControlClient): - def __init__(self, **conn_kwargs) -> None: - self._socket_path = Path(conn_kwargs.pop('path')) + """Task pool control client that expects a unix socket to be exposed by the control server.""" + + def __init__(self, socket_path: PathT, **conn_kwargs) -> None: + """ + In addition to what the base class does, the `socket_path` is expected as a non-optional argument. + + The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument. + """ + self._socket_path = Path(socket_path) super().__init__(**conn_kwargs) - async def open_connection(self, **kwargs) -> ClientConnT: + async def _open_connection(self, **kwargs) -> ClientConnT: + """ + Wrapper around the `asyncio.open_unix_connection` function. + + Returns a tuple of `None` and `None`, if the socket is not found at the pre-defined path; + otherwise, the stream-reader and -writer tuple is returned. + """ try: return await open_unix_connection(self._socket_path, **kwargs) except FileNotFoundError: diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index 24f373f..f6177e7 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -20,6 +20,7 @@ Custom type definitions used in various modules. from asyncio.streams import StreamReader, StreamWriter +from pathlib import Path from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union @@ -36,3 +37,5 @@ CancelCallbackT = Callable ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] + +PathT = Union[Path, str] diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..d101f47 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,207 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +Unittests for the `asyncio_taskpool.client` module. +""" + + +import json +import shutil +import sys +from pathlib import Path +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from asyncio_taskpool import client +from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES + + +FOO, BAR = 'foo', 'bar' + + +class ControlClientTestCase(IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set()) + self.print_patcher = patch.object(client, 'print') + self.mock_abstract_methods = self.abstract_patcher.start() + self.mock_print = self.print_patcher.start() + self.kwargs = {FOO: 123, BAR: 456} + self.client = client.ControlClient(**self.kwargs) + + self.mock_read = AsyncMock(return_value=FOO.encode()) + self.mock_write, self.mock_drain = MagicMock(), AsyncMock() + self.mock_reader = MagicMock(read=self.mock_read) + self.mock_writer = MagicMock(write=self.mock_write, drain=self.mock_drain) + + def tearDown(self) -> None: + self.abstract_patcher.stop() + self.print_patcher.stop() + + def test_client_info(self): + self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}, + client.ControlClient.client_info()) + + async def test_abstract(self): + with self.assertRaises(NotImplementedError): + await self.client._open_connection(**self.kwargs) + + def test_init(self): + self.assertEqual(self.kwargs, self.client._conn_kwargs) + self.assertFalse(self.client._connected) + + @patch.object(client.ControlClient, 'client_info') + async def test__server_handshake(self, mock_client_info: MagicMock): + mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999} + self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer)) + self.assertTrue(self.client._connected) + mock_client_info.assert_called_once_with() + self.mock_write.assert_called_once_with(json.dumps(mock_info).encode()) + self.mock_drain.assert_awaited_once_with() + self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) + self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode()) + + @patch.object(client, 'input') + def test__get_command(self, mock_input: MagicMock): + self.client._connected = True + + mock_input.return_value = ' ' + FOO.upper() + ' ' + mock_close = MagicMock() + mock_writer = MagicMock(close=mock_close) + output = self.client._get_command(mock_writer) + self.assertEqual(FOO, output) + mock_input.assert_called_once() + mock_close.assert_not_called() + self.assertTrue(self.client._connected) + + mock_input.reset_mock() + mock_input.side_effect = KeyboardInterrupt + self.assertIsNone(self.client._get_command(mock_writer)) + mock_input.assert_called_once() + mock_close.assert_not_called() + self.assertTrue(self.client._connected) + + mock_input.reset_mock() + mock_input.side_effect = EOFError + self.assertIsNone(self.client._get_command(mock_writer)) + mock_input.assert_called_once() + mock_close.assert_called_once() + self.assertFalse(self.client._connected) + + @patch.object(client.ControlClient, '_get_command') + async def test__interact(self, mock__get_command: MagicMock): + self.client._connected = True + + mock__get_command.return_value = None + self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) + self.mock_write.assert_not_called() + self.mock_drain.assert_not_awaited() + self.mock_read.assert_not_awaited() + self.mock_print.assert_not_called() + self.assertTrue(self.client._connected) + + mock__get_command.return_value = cmd = FOO + BAR + ' 123' + self.mock_drain.side_effect = err = ConnectionError() + self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) + self.mock_write.assert_called_once_with(cmd.encode()) + self.mock_drain.assert_awaited_once_with() + self.mock_read.assert_not_awaited() + self.mock_print.assert_called_once_with(err, file=sys.stderr) + self.assertFalse(self.client._connected) + + self.client._connected = True + self.mock_write.reset_mock() + self.mock_drain.reset_mock(side_effect=True) + self.mock_print.reset_mock() + + self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) + self.mock_write.assert_called_once_with(cmd.encode()) + self.mock_drain.assert_awaited_once_with() + self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) + self.mock_print.assert_called_once_with(FOO) + self.assertTrue(self.client._connected) + + @patch.object(client.ControlClient, '_interact') + @patch.object(client.ControlClient, '_server_handshake') + @patch.object(client.ControlClient, '_open_connection') + async def test_start(self, mock__open_connection: AsyncMock, mock__server_handshake: AsyncMock, + mock__interact: AsyncMock): + mock__open_connection.return_value = None, None + self.assertIsNone(await self.client.start()) + mock__open_connection.assert_awaited_once_with(**self.kwargs) + mock__server_handshake.assert_not_awaited() + mock__interact.assert_not_awaited() + self.mock_print.assert_called_once_with("Failed to connect.", file=sys.stderr) + + mock__open_connection.reset_mock() + self.mock_print.reset_mock() + + mock__open_connection.return_value = self.mock_reader, self.mock_writer + self.assertIsNone(await self.client.start()) + mock__open_connection.assert_awaited_once_with(**self.kwargs) + mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer) + mock__interact.assert_not_awaited() + self.mock_print.assert_called_once_with("Disconnected from control server.") + + mock__open_connection.reset_mock() + mock__server_handshake.reset_mock() + self.mock_print.reset_mock() + + self.client._connected = True + def disconnect(*_args, **_kwargs) -> None: self.client._connected = False + mock__interact.side_effect = disconnect + self.assertIsNone(await self.client.start()) + mock__open_connection.assert_awaited_once_with(**self.kwargs) + mock__server_handshake.assert_awaited_once_with(self.mock_reader, self.mock_writer) + mock__interact.assert_awaited_once_with(self.mock_reader, self.mock_writer) + self.mock_print.assert_called_once_with("Disconnected from control server.") + + +class UnixControlClientTestCase(IsolatedAsyncioTestCase): + + def setUp(self) -> None: + self.base_init_patcher = patch.object(client.ControlClient, '__init__') + self.mock_base_init = self.base_init_patcher.start() + self.path = '/tmp/asyncio_taskpool' + self.kwargs = {FOO: 123, BAR: 456} + self.client = client.UnixControlClient(socket_path=self.path, **self.kwargs) + + def tearDown(self) -> None: + self.base_init_patcher.stop() + + def test_init(self): + self.assertEqual(Path(self.path), self.client._socket_path) + self.mock_base_init.assert_called_once_with(**self.kwargs) + + @patch.object(client, 'print') + @patch.object(client, 'open_unix_connection') + async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock): + mock_open_unix_connection.return_value = expected_output = 'something' + kwargs = {'a': 1, 'b': 2} + output = await self.client._open_connection(**kwargs) + self.assertEqual(expected_output, output) + mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs) + mock_print.assert_not_called() + + mock_open_unix_connection.reset_mock() + + mock_open_unix_connection.side_effect = FileNotFoundError + output1, output2 = await self.client._open_connection(**kwargs) + self.assertIsNone(output1) + self.assertIsNone(output2) + mock_open_unix_connection.assert_awaited_once_with(Path(self.path), **kwargs) + mock_print.assert_called_once_with("No socket at", Path(self.path), file=sys.stderr)