full unit test coverage and docstrings for session module

This commit is contained in:
Daniil Fajnberg 2022-02-14 17:59:11 +01:00
parent 3f3eb7ce38
commit 96d01e7259
5 changed files with 516 additions and 24 deletions

View File

@ -8,5 +8,8 @@ omit =
fail_under = 100 fail_under = 100
show_missing = True show_missing = True
skip_covered = False skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
omit = omit =
tests/* tests/*

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.3.2 version = 0.3.3
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -1,6 +1,27 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the the definition of the control session class used by the control server.
"""
import logging import logging
import json import json
from argparse import ArgumentError, HelpFormatter, Namespace from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
@ -18,7 +39,29 @@ log = logging.getLogger(__name__)
class ControlSession: class ControlSession:
"""
This class defines the API for controlling a task pool instance from the outside.
The commands received from a connected client are translated into method calls on the task pool instance.
A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
"""
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
"""
Instantiation should happen once a client connection to the control server has already been established.
For more convenient/efficient access, some of the server's properties are saved in separate attributes.
The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.
Args:
server:
The instance of a `ControlServer` subclass starting the session.
reader:
The `asyncio.StreamReader` created when a client connected to the server.
writer:
The `asyncio.StreamWriter` created when a client connected to the server.
"""
self._control_server: 'ControlServer' = server self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name self._client_class_name = server.client_class_name
@ -27,8 +70,31 @@ class ControlSession:
self._parser: Optional[CommandParser] = None self._parser: Optional[CommandParser] = None
self._subparsers = None self._subparsers = None
def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None, def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
**kwargs) -> CommandParser: **kwargs) -> CommandParser:
"""
Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.
Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.
Args:
name:
The command name; passed directly into the `add_parser` method.
prog (optional):
Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
equal to the `name` argument.
short_help (optional):
Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
`long_help` argument is present; in that case the `long_help` argument is passed as `help`.
long_help (optional):
Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
**kwargs (optional):
Any keyword-arguments to directly pass into the `add_parser` method.
Returns:
An instance of the `CommandParser` class representing the newly added control command.
"""
if prog is None: if prog is None:
prog = name prog = name
kwargs.setdefault('help', short_help or long_help) kwargs.setdefault('help', short_help or long_help)
@ -36,9 +102,16 @@ class ControlSession:
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs) return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
def _add_base_commands(self) -> None: def _add_base_commands(self) -> None:
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD) """
self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__)) Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.
self._add_parser_command(
These include commands mapping to the following pool methods:
- __str__
- pool_size (get/set property)
- num_running
"""
self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
self._add_command(
CMD.POOL_SIZE, CMD.POOL_SIZE,
short_help="Get/set the maximum number of tasks in the pool.", short_help="Get/set the maximum number of tasks in the pool.",
formatter_class=HelpFormatter formatter_class=HelpFormatter
@ -47,36 +120,57 @@ class ControlSession:
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
) )
self._add_parser_command( self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)
)
def _add_simple_commands(self) -> None: def _add_simple_commands(self) -> None:
self._add_parser_command( """
Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.
These include commands mapping to the following pool methods:
- start
- stop
- stop_all
- func_name
"""
self._add_command(
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start) CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
).add_optional_num_argument( ).add_optional_num_argument(
help="Number of tasks to start." help="Number of tasks to start."
) )
self._add_parser_command( self._add_command(
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop) CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
).add_optional_num_argument( ).add_optional_num_argument(
help="Number of tasks to stop." help="Number of tasks to stop."
) )
self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all)) self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
self._add_parser_command( self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))
CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)
)
def _add_advanced_commands(self) -> None: def _add_advanced_commands(self) -> None:
"""
Adds the commands that are only supported, if a `TaskPool` object is controlled.
These include commands mapping to the following pool methods:
- ...
"""
raise NotImplementedError raise NotImplementedError
def _init_parser(self, client_terminal_width: int) -> None: def _init_parser(self, client_terminal_width: int) -> None:
"""
Initializes and fully configures the `CommandParser` responsible for handling the input.
Depending on what specific task pool class is controlled by the server, different commands are added.
Args:
client_terminal_width:
The number of columns of the client's terminal to be able to nicely format messages from the parser.
"""
parser_kwargs = { parser_kwargs = {
'prog': '', 'prog': '',
SESSION_PARSER_WRITER: self._writer, SESSION_PARSER_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
} }
self._parser = CommandParser(**parser_kwargs) self._parser = CommandParser(**parser_kwargs)
self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
self._add_base_commands() self._add_base_commands()
if isinstance(self._pool, TaskPool): if isinstance(self._pool, TaskPool):
self._add_advanced_commands() self._add_advanced_commands()
@ -88,6 +182,11 @@ class ControlSession:
raise NotATaskPool(f"Not a task pool instance: {self._pool}") raise NotATaskPool(f"Not a task pool instance: {self._pool}")
async def client_handshake(self) -> None: async def client_handshake(self) -> None:
"""
This method must be invoked before starting any other client interaction.
Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
"""
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH]) self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
@ -95,14 +194,31 @@ class ControlSession:
await self._writer.drain() await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None: async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
"""
Acts as a wrapper around a call to a specific task pool method.
The method is called and any exception is caught and saved. If there is no output and no exception caught, a
generic confirmation message is sent back to the client. Otherwise the output or a string representation of
the exception caught is sent back.
Args:
func:
Reference to the task pool method.
*args (optional):
Any positional arguments to call the method with.
*+kwargs (optional):
Any keyword-arguments to call the method with.
"""
output = await return_or_exception(func, *args, **kwargs) output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode()) self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None: async def _cmd_name(self, **_kwargs) -> None:
"""Maps to the `__str__` method of any task pool class."""
log.debug("%s requests task pool name", self._client_class_name) log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool) await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None: async def _cmd_pool_size(self, **kwargs) -> None:
"""Maps to the `pool_size` property of any task pool class."""
num = kwargs.get(NUM) num = kwargs.get(NUM)
if num is None: if num is None:
log.debug("%s requests pool size", self._client_class_name) log.debug("%s requests pool size", self._client_class_name)
@ -112,45 +228,73 @@ class ControlSession:
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num) await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_num_running(self, **_kwargs) -> None: async def _cmd_num_running(self, **_kwargs) -> None:
"""Maps to the `num_running` property of any task pool class."""
log.debug("%s requests number of running tasks", self._client_class_name) log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool) await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_start(self, **kwargs) -> None: async def _cmd_start(self, **kwargs) -> None:
"""Maps to the `start` method of the `SimpleTaskPool` class."""
num = kwargs[NUM] num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num)) log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num) await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None: async def _cmd_stop(self, **kwargs) -> None:
"""Maps to the `stop` method of the `SimpleTaskPool` class."""
num = kwargs[NUM] num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num)) log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num) await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None: async def _cmd_stop_all(self, **_kwargs) -> None:
"""Maps to the `stop_all` method of the `SimpleTaskPool` class."""
log.debug("%s requests stopping all tasks", self._client_class_name) log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all) await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None: async def _cmd_func_name(self, **_kwargs) -> None:
"""Maps to the `func_name` method of the `SimpleTaskPool` class."""
log.debug("%s requests pool function name", self._client_class_name) log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool) await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
async def _execute_command(self, args: Namespace) -> None: async def _execute_command(self, **kwargs) -> None:
args = vars(args) """
cmd: str = args.pop(CMD.CMD, None) Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.
if cmd is not None:
method = getattr(self, f'_cmd_{cmd.replace("-", "_")}') Args:
await method(**args) **kwargs:
Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
simply passed into the method determined from the command name.
"""
method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
await method(**kwargs)
async def _parse_command(self, msg: str) -> None: async def _parse_command(self, msg: str) -> None:
"""
Takes a message from the client and attempts to parse it.
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
`CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
dictionary of keyword-arguments returned by the `CommandParser` passed into it.
Args:
msg:
The non-empty string read from the client stream.
"""
try: try:
args = self._parser.parse_args(msg.split(' ')) kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e: except ArgumentError as e:
self._writer.write(str(e).encode()) self._writer.write(str(e).encode())
return return
except HelpRequested: except HelpRequested:
return return
await self._execute_command(args) await self._execute_command(**kwargs)
async def listen(self) -> None: async def listen(self) -> None:
"""
Enters the main control loop that only ends if either the server or the client disconnect.
Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
This method should be called, when the client connection was established and the handshake was successful.
It will obviously block indefinitely.
"""
while self._control_server.is_serving(): while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg: if not msg:

View File

@ -1,3 +1,24 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.server` module.
"""
import asyncio import asyncio
import logging import logging
from pathlib import Path from pathlib import Path

324
tests/test_session.py Normal file
View File

@ -0,0 +1,324 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Unittests for the `asyncio_taskpool.session` module.
"""
import json
from argparse import ArgumentError, Namespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER
from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = session.log.level
session.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
session.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
self.mock_client_class_name = FOO + BAR
self.mock_server = MagicMock(pool=self.mock_pool,
client_class_name=self.mock_client_class_name)
self.mock_reader = MagicMock()
self.mock_writer = MagicMock()
self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
def test_init(self):
self.assertEqual(self.mock_server, self.session._control_server)
self.assertEqual(self.mock_pool, self.session._pool)
self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser)
self.assertIsNone(self.session._subparsers)
def test__add_command(self):
expected_output = 123456
mock_add_parser = MagicMock(return_value=expected_output)
self.session._subparsers = MagicMock(add_parser=mock_add_parser)
self.session._parser = MagicMock()
name, prog, short_help, long_help = 'abc', None, 'short123', None
kwargs = {'x': 1, 'y': 2}
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
parent=self.session._parser, **kwargs)
mock_add_parser.reset_mock()
prog, long_help = 'ffffff', 'so long, wow'
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
parent=self.session._parser, **kwargs)
mock_add_parser.reset_mock()
short_help = None
output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
self.assertEqual(expected_output, output)
mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
parent=self.session._parser, **kwargs)
@patch.object(session, 'get_first_doc_line')
@patch.object(session.ControlSession, '_add_command')
def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
self.assertIsNone(self.session._add_base_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
mock__add_command.reset_mock()
mock_get_first_doc_line.reset_mock()
self.assertIsNone(self.session._add_simple_commands())
mock__add_command.assert_called()
mock_get_first_doc_line.assert_called()
with self.assertRaises(NotImplementedError):
self.session._add_advanced_commands()
@patch.object(session.ControlSession, '_add_simple_commands')
@patch.object(session.ControlSession, '_add_advanced_commands')
@patch.object(session.ControlSession, '_add_base_commands')
@patch.object(session, 'CommandParser')
def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
mock_command_parser_cls.return_value = mock_parser = MagicMock()
self.session._pool = TaskPool()
width = 1234
expected_parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
}
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_called_once_with()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
async def fake_coroutine(): pass
self.session._pool = SimpleTaskPool(fake_coroutine)
self.assertIsNone(self.session._init_parser(width))
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_called_once_with()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
class FakeTaskPool(BaseTaskPool):
pass
self.session._pool = FakeTaskPool()
with self.assertRaises(UnknownTaskPoolClass):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
mock_command_parser_cls.reset_mock()
mock_parser.add_subparsers.reset_mock()
mock__add_base_commands.reset_mock()
mock__add_advanced_commands.reset_mock()
mock__add_simple_commands.reset_mock()
self.session._pool = MagicMock()
with self.assertRaises(NotATaskPool):
self.session._init_parser(width)
mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
mock__add_base_commands.assert_called_once_with()
mock__add_advanced_commands.assert_not_called()
mock__add_simple_commands.assert_not_called()
@patch.object(session.ControlSession, '_init_parser')
async def test_client_handshake(self, mock__init_parser: MagicMock):
width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock()
self.assertIsNone(await self.session.client_handshake())
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
mock__init_parser.assert_called_once_with(width)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session, 'return_or_exception')
async def test__write_function_output(self, mock_return_or_exception: MagicMock):
self.mock_writer.write = MagicMock()
mock_return_or_exception.return_value = None
func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(b"ok")
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
mock_return_or_exception.return_value = output = MagicMock()
self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
self.mock_writer.write.assert_called_once_with(str(output).encode())
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_name())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fset, self.session._pool, num
)
mock__write_function_output.reset_mock()
kwargs.pop(session.NUM)
self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.pool_size.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_num_running())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.num_running.fget, self.session._pool
)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_start(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_start(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
num = 12345
kwargs = {session.NUM: num, FOO: BAR}
self.assertIsNone(await self.session._cmd_stop(**kwargs))
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_stop_all())
mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
@patch.object(session.ControlSession, '_write_function_output')
async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
self.assertIsNone(await self.session._cmd_func_name())
mock__write_function_output.assert_awaited_once_with(
self.mock_pool.__class__.func_name.fget, self.session._pool
)
async def test__execute_command(self):
mock_method = AsyncMock()
cmd = 'this-is-a-test'
setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
kwargs = {FOO: BAR, 'hello': 'python'}
self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
mock_method.assert_awaited_once_with(**kwargs)
@patch.object(session.ControlSession, '_execute_command')
async def test__parse_command(self, mock__execute_command: AsyncMock):
msg = 'asdf asd as a'
kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_awaited_once_with(**kwargs)
mock__execute_command.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__execute_command.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
mock__execute_command.assert_not_awaited()
@patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty():
self.mock_reader.read.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg)
self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock()
self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited()
self.mock_writer.drain.assert_not_awaited()