finally reached 100% unittest coverage overall

This commit is contained in:
Daniil Fajnberg 2022-03-13 16:11:20 +01:00
parent 6f082288d8
commit 38f4ec1b06
2 changed files with 67 additions and 23 deletions

View File

@ -19,11 +19,10 @@ CLI client entry point.
"""
import sys
from argparse import ArgumentParser
from asyncio import run
from pathlib import Path
from typing import Dict, Any
from typing import Any, Dict, Sequence
from ..constants import PACKAGE_NAME
from ..pool import TaskPool
@ -31,25 +30,19 @@ from .client import ControlClient, TCPControlClient, UnixControlClient
from .server import TCPControlServer, UnixControlServer
CONN_TYPE = 'conn_type'
CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
HOST, PORT = 'host', 'port'
def parse_cli() -> Dict[str, Any]:
def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]:
parser = ArgumentParser(
prog=PACKAGE_NAME,
description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
prog=f'{PACKAGE_NAME}.control',
description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types")
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
tcp_parser.add_argument(
HOST,
@ -60,19 +53,25 @@ def parse_cli() -> Dict[str, Any]:
type=int,
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
return vars(parser.parse_args())
tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient})
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
)
unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient})
return vars(parser.parse_args(args))
async def main():
kwargs = parse_cli()
if kwargs[CONN_TYPE] == UNIX:
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP:
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
else:
print("Invalid connection type", file=sys.stderr)
sys.exit(2)
await client.start()
client_cls = kwargs.pop(CLIENT_CLASS)
await client_cls(**kwargs).start()
if __name__ == '__main__':
run(main())

View File

@ -0,0 +1,45 @@
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient
from asyncio_taskpool.control import __main__ as module
class CLITestCase(IsolatedAsyncioTestCase):
def test_parse_cli(self):
socket_path = '/some/path/to.sock'
args = [module.UNIX, socket_path]
expected_kwargs = {
module.CLIENT_CLASS: UnixControlClient,
module.SOCKET_PATH: Path(socket_path)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
host, port = '1.2.3.4', '1234'
args = [module.TCP, host, port]
expected_kwargs = {
module.CLIENT_CLASS: TCPControlClient,
module.HOST: host,
module.PORT: int(port)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
with patch('sys.stderr'):
with self.assertRaises(SystemExit):
module.parse_cli(['invalid', 'foo', 'bar'])
@patch.object(module, 'parse_cli')
async def test_main(self, mock_parse_cli: MagicMock):
mock_client_start = AsyncMock()
mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
mock_client_start.assert_awaited_once_with()