diff --git a/src/asyncio_taskpool/control/__main__.py b/src/asyncio_taskpool/control/__main__.py index 86ed53a..2e65fee 100644 --- a/src/asyncio_taskpool/control/__main__.py +++ b/src/asyncio_taskpool/control/__main__.py @@ -19,11 +19,10 @@ CLI client entry point. """ -import sys from argparse import ArgumentParser from asyncio import run from pathlib import Path -from typing import Dict, Any +from typing import Any, Dict, Sequence from ..constants import PACKAGE_NAME from ..pool import TaskPool @@ -31,25 +30,19 @@ from .client import ControlClient, TCPControlClient, UnixControlClient from .server import TCPControlServer, UnixControlServer -CONN_TYPE = 'conn_type' +CLIENT_CLASS = 'client_class' UNIX, TCP = 'unix', 'tcp' SOCKET_PATH = 'path' HOST, PORT = 'host', 'port' -def parse_cli() -> Dict[str, Any]: +def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]: parser = ArgumentParser( - prog=PACKAGE_NAME, - description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}" - ) - subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE) - unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket") - unix_parser.add_argument( - SOCKET_PATH, - type=Path, - help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is " - f"listening." + prog=f'{PACKAGE_NAME}.control', + description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}" ) + subparsers = parser.add_subparsers(title="Connection types") + tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket") tcp_parser.add_argument( HOST, @@ -60,19 +53,25 @@ def parse_cli() -> Dict[str, Any]: type=int, help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on." ) - return vars(parser.parse_args()) + tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient}) + + unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket") + unix_parser.add_argument( + SOCKET_PATH, + type=Path, + help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is " + f"listening." + ) + unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient}) + + return vars(parser.parse_args(args)) async def main(): kwargs = parse_cli() - if kwargs[CONN_TYPE] == UNIX: - client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) - elif kwargs[CONN_TYPE] == TCP: - client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT]) - else: - print("Invalid connection type", file=sys.stderr) - sys.exit(2) - await client.start() + client_cls = kwargs.pop(CLIENT_CLASS) + await client_cls(**kwargs).start() + if __name__ == '__main__': run(main()) diff --git a/tests/test_control/test___main__.py b/tests/test_control/test___main__.py new file mode 100644 index 0000000..62a5747 --- /dev/null +++ b/tests/test_control/test___main__.py @@ -0,0 +1,45 @@ +from pathlib import Path +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient +from asyncio_taskpool.control import __main__ as module + + +class CLITestCase(IsolatedAsyncioTestCase): + + def test_parse_cli(self): + socket_path = '/some/path/to.sock' + args = [module.UNIX, socket_path] + expected_kwargs = { + module.CLIENT_CLASS: UnixControlClient, + module.SOCKET_PATH: Path(socket_path) + } + parsed_kwargs = module.parse_cli(args) + self.assertDictEqual(expected_kwargs, parsed_kwargs) + + host, port = '1.2.3.4', '1234' + args = [module.TCP, host, port] + expected_kwargs = { + module.CLIENT_CLASS: TCPControlClient, + module.HOST: host, + module.PORT: int(port) + } + parsed_kwargs = module.parse_cli(args) + self.assertDictEqual(expected_kwargs, parsed_kwargs) + + with patch('sys.stderr'): + with self.assertRaises(SystemExit): + module.parse_cli(['invalid', 'foo', 'bar']) + + @patch.object(module, 'parse_cli') + async def test_main(self, mock_parse_cli: MagicMock): + mock_client_start = AsyncMock() + mock_client = MagicMock(start=mock_client_start) + mock_client_cls = MagicMock(return_value=mock_client) + mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789} + mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs + self.assertIsNone(await module.main()) + mock_parse_cli.assert_called_once_with() + mock_client_cls.assert_called_once_with(**mock_client_kwargs) + mock_client_start.assert_awaited_once_with()