generated from daniil-berg/boilerplate-py
finally reached 100% unittest coverage overall
This commit is contained in:
parent
6f082288d8
commit
38f4ec1b06
@ -19,11 +19,10 @@ CLI client entry point.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from asyncio import run
|
from asyncio import run
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Any, Dict, Sequence
|
||||||
|
|
||||||
from ..constants import PACKAGE_NAME
|
from ..constants import PACKAGE_NAME
|
||||||
from ..pool import TaskPool
|
from ..pool import TaskPool
|
||||||
@ -31,25 +30,19 @@ from .client import ControlClient, TCPControlClient, UnixControlClient
|
|||||||
from .server import TCPControlServer, UnixControlServer
|
from .server import TCPControlServer, UnixControlServer
|
||||||
|
|
||||||
|
|
||||||
CONN_TYPE = 'conn_type'
|
CLIENT_CLASS = 'client_class'
|
||||||
UNIX, TCP = 'unix', 'tcp'
|
UNIX, TCP = 'unix', 'tcp'
|
||||||
SOCKET_PATH = 'path'
|
SOCKET_PATH = 'path'
|
||||||
HOST, PORT = 'host', 'port'
|
HOST, PORT = 'host', 'port'
|
||||||
|
|
||||||
|
|
||||||
def parse_cli() -> Dict[str, Any]:
|
def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]:
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
prog=PACKAGE_NAME,
|
prog=f'{PACKAGE_NAME}.control',
|
||||||
description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
|
description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
|
||||||
)
|
|
||||||
subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
|
|
||||||
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
|
|
||||||
unix_parser.add_argument(
|
|
||||||
SOCKET_PATH,
|
|
||||||
type=Path,
|
|
||||||
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
|
|
||||||
f"listening."
|
|
||||||
)
|
)
|
||||||
|
subparsers = parser.add_subparsers(title="Connection types")
|
||||||
|
|
||||||
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
|
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
|
||||||
tcp_parser.add_argument(
|
tcp_parser.add_argument(
|
||||||
HOST,
|
HOST,
|
||||||
@ -60,19 +53,25 @@ def parse_cli() -> Dict[str, Any]:
|
|||||||
type=int,
|
type=int,
|
||||||
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
|
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
|
||||||
)
|
)
|
||||||
return vars(parser.parse_args())
|
tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient})
|
||||||
|
|
||||||
|
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
|
||||||
|
unix_parser.add_argument(
|
||||||
|
SOCKET_PATH,
|
||||||
|
type=Path,
|
||||||
|
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
|
||||||
|
f"listening."
|
||||||
|
)
|
||||||
|
unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient})
|
||||||
|
|
||||||
|
return vars(parser.parse_args(args))
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
kwargs = parse_cli()
|
kwargs = parse_cli()
|
||||||
if kwargs[CONN_TYPE] == UNIX:
|
client_cls = kwargs.pop(CLIENT_CLASS)
|
||||||
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
|
await client_cls(**kwargs).start()
|
||||||
elif kwargs[CONN_TYPE] == TCP:
|
|
||||||
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
|
|
||||||
else:
|
|
||||||
print("Invalid connection type", file=sys.stderr)
|
|
||||||
sys.exit(2)
|
|
||||||
await client.start()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run(main())
|
run(main())
|
||||||
|
45
tests/test_control/test___main__.py
Normal file
45
tests/test_control/test___main__.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from unittest import IsolatedAsyncioTestCase
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient
|
||||||
|
from asyncio_taskpool.control import __main__ as module
|
||||||
|
|
||||||
|
|
||||||
|
class CLITestCase(IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
def test_parse_cli(self):
|
||||||
|
socket_path = '/some/path/to.sock'
|
||||||
|
args = [module.UNIX, socket_path]
|
||||||
|
expected_kwargs = {
|
||||||
|
module.CLIENT_CLASS: UnixControlClient,
|
||||||
|
module.SOCKET_PATH: Path(socket_path)
|
||||||
|
}
|
||||||
|
parsed_kwargs = module.parse_cli(args)
|
||||||
|
self.assertDictEqual(expected_kwargs, parsed_kwargs)
|
||||||
|
|
||||||
|
host, port = '1.2.3.4', '1234'
|
||||||
|
args = [module.TCP, host, port]
|
||||||
|
expected_kwargs = {
|
||||||
|
module.CLIENT_CLASS: TCPControlClient,
|
||||||
|
module.HOST: host,
|
||||||
|
module.PORT: int(port)
|
||||||
|
}
|
||||||
|
parsed_kwargs = module.parse_cli(args)
|
||||||
|
self.assertDictEqual(expected_kwargs, parsed_kwargs)
|
||||||
|
|
||||||
|
with patch('sys.stderr'):
|
||||||
|
with self.assertRaises(SystemExit):
|
||||||
|
module.parse_cli(['invalid', 'foo', 'bar'])
|
||||||
|
|
||||||
|
@patch.object(module, 'parse_cli')
|
||||||
|
async def test_main(self, mock_parse_cli: MagicMock):
|
||||||
|
mock_client_start = AsyncMock()
|
||||||
|
mock_client = MagicMock(start=mock_client_start)
|
||||||
|
mock_client_cls = MagicMock(return_value=mock_client)
|
||||||
|
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
|
||||||
|
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
|
||||||
|
self.assertIsNone(await module.main())
|
||||||
|
mock_parse_cli.assert_called_once_with()
|
||||||
|
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
|
||||||
|
mock_client_start.assert_awaited_once_with()
|
Loading…
Reference in New Issue
Block a user