finally reached 100% unittest coverage overall

This commit is contained in:
Daniil Fajnberg 2022-03-13 16:11:20 +01:00
parent 6f082288d8
commit 38f4ec1b06
2 changed files with 67 additions and 23 deletions

View File

@ -19,11 +19,10 @@ CLI client entry point.
""" """
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from asyncio import run from asyncio import run
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Any, Dict, Sequence
from ..constants import PACKAGE_NAME from ..constants import PACKAGE_NAME
from ..pool import TaskPool from ..pool import TaskPool
@ -31,25 +30,19 @@ from .client import ControlClient, TCPControlClient, UnixControlClient
from .server import TCPControlServer, UnixControlServer from .server import TCPControlServer, UnixControlServer
CONN_TYPE = 'conn_type' CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp' UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path' SOCKET_PATH = 'path'
HOST, PORT = 'host', 'port' HOST, PORT = 'host', 'port'
def parse_cli() -> Dict[str, Any]: def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]:
parser = ArgumentParser( parser = ArgumentParser(
prog=PACKAGE_NAME, prog=f'{PACKAGE_NAME}.control',
description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}" description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
) )
subparsers = parser.add_subparsers(title="Connection types")
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket") tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
tcp_parser.add_argument( tcp_parser.add_argument(
HOST, HOST,
@ -60,19 +53,25 @@ def parse_cli() -> Dict[str, Any]:
type=int, type=int,
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on." help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
) )
return vars(parser.parse_args()) tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient})
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
)
unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient})
return vars(parser.parse_args(args))
async def main(): async def main():
kwargs = parse_cli() kwargs = parse_cli()
if kwargs[CONN_TYPE] == UNIX: client_cls = kwargs.pop(CLIENT_CLASS)
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) await client_cls(**kwargs).start()
elif kwargs[CONN_TYPE] == TCP:
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
else:
print("Invalid connection type", file=sys.stderr)
sys.exit(2)
await client.start()
if __name__ == '__main__': if __name__ == '__main__':
run(main()) run(main())

View File

@ -0,0 +1,45 @@
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient
from asyncio_taskpool.control import __main__ as module
class CLITestCase(IsolatedAsyncioTestCase):
def test_parse_cli(self):
socket_path = '/some/path/to.sock'
args = [module.UNIX, socket_path]
expected_kwargs = {
module.CLIENT_CLASS: UnixControlClient,
module.SOCKET_PATH: Path(socket_path)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
host, port = '1.2.3.4', '1234'
args = [module.TCP, host, port]
expected_kwargs = {
module.CLIENT_CLASS: TCPControlClient,
module.HOST: host,
module.PORT: int(port)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
with patch('sys.stderr'):
with self.assertRaises(SystemExit):
module.parse_cli(['invalid', 'foo', 'bar'])
@patch.object(module, 'parse_cli')
async def test_main(self, mock_parse_cli: MagicMock):
mock_client_start = AsyncMock()
mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
mock_client_start.assert_awaited_once_with()