diff --git a/.coveragerc b/.coveragerc index 11a7f12..f70c551 100644 --- a/.coveragerc +++ b/.coveragerc @@ -8,5 +8,8 @@ omit = fail_under = 100 show_missing = True skip_covered = False +exclude_lines = + if TYPE_CHECKING: + if __name__ == ['"]__main__['"]: omit = tests/* diff --git a/setup.cfg b/setup.cfg index 5c2e218..c0caba2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.3.2 +version = 0.3.3 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/session.py b/src/asyncio_taskpool/session.py index b35a0f8..6db1d01 100644 --- a/src/asyncio_taskpool/session.py +++ b/src/asyncio_taskpool/session.py @@ -1,6 +1,27 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +This module contains the the definition of the control session class used by the control server. +""" + + import logging import json -from argparse import ArgumentError, HelpFormatter, Namespace +from argparse import ArgumentError, HelpFormatter from asyncio.streams import StreamReader, StreamWriter from typing import Callable, Optional, Union, TYPE_CHECKING @@ -18,7 +39,29 @@ log = logging.getLogger(__name__) class ControlSession: + """ + This class defines the API for controlling a task pool instance from the outside. + + The commands received from a connected client are translated into method calls on the task pool instance. + A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream. + """ + def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: + """ + Instantiation should happen once a client connection to the control server has already been established. + + For more convenient/efficient access, some of the server's properties are saved in separate attributes. + The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during + initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured. + + Args: + server: + The instance of a `ControlServer` subclass starting the session. + reader: + The `asyncio.StreamReader` created when a client connected to the server. + writer: + The `asyncio.StreamWriter` created when a client connected to the server. + """ self._control_server: 'ControlServer' = server self._pool: Union[TaskPool, SimpleTaskPool] = server.pool self._client_class_name = server.client_class_name @@ -27,8 +70,31 @@ class ControlSession: self._parser: Optional[CommandParser] = None self._subparsers = None - def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None, - **kwargs) -> CommandParser: + def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None, + **kwargs) -> CommandParser: + """ + Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance. + + Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument. + + Args: + name: + The command name; passed directly into the `add_parser` method. + prog (optional): + Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set + equal to the `name` argument. + short_help (optional): + Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the + `long_help` argument is present; in that case the `long_help` argument is passed as `help`. + long_help (optional): + Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and + the `short_help` argument is present; in that case the `short_help` argument is passed as `description`. + **kwargs (optional): + Any keyword-arguments to directly pass into the `add_parser` method. + + Returns: + An instance of the `CommandParser` class representing the newly added control command. + """ if prog is None: prog = name kwargs.setdefault('help', short_help or long_help) @@ -36,9 +102,16 @@ class ControlSession: return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs) def _add_base_commands(self) -> None: - self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD) - self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__)) - self._add_parser_command( + """ + Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled. + + These include commands mapping to the following pool methods: + - __str__ + - pool_size (get/set property) + - num_running + """ + self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__)) + self._add_command( CMD.POOL_SIZE, short_help="Get/set the maximum number of tasks in the pool.", formatter_class=HelpFormatter @@ -47,36 +120,57 @@ class ControlSession: help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" ) - self._add_parser_command( - CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget) - ) + self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)) def _add_simple_commands(self) -> None: - self._add_parser_command( + """ + Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled. + + These include commands mapping to the following pool methods: + - start + - stop + - stop_all + - func_name + """ + self._add_command( CMD.START, short_help=get_first_doc_line(self._pool.__class__.start) ).add_optional_num_argument( help="Number of tasks to start." ) - self._add_parser_command( + self._add_command( CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop) ).add_optional_num_argument( help="Number of tasks to stop." ) - self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all)) - self._add_parser_command( - CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget) - ) + self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all)) + self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)) def _add_advanced_commands(self) -> None: + """ + Adds the commands that are only supported, if a `TaskPool` object is controlled. + + These include commands mapping to the following pool methods: + - ... + """ raise NotImplementedError def _init_parser(self, client_terminal_width: int) -> None: + """ + Initializes and fully configures the `CommandParser` responsible for handling the input. + + Depending on what specific task pool class is controlled by the server, different commands are added. + + Args: + client_terminal_width: + The number of columns of the client's terminal to be able to nicely format messages from the parser. + """ parser_kwargs = { 'prog': '', SESSION_PARSER_WRITER: self._writer, CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, } self._parser = CommandParser(**parser_kwargs) + self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD) self._add_base_commands() if isinstance(self._pool, TaskPool): self._add_advanced_commands() @@ -88,6 +182,11 @@ class ControlSession: raise NotATaskPool(f"Not a task pool instance: {self._pool}") async def client_handshake(self) -> None: + """ + This method must be invoked before starting any other client interaction. + + Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured. + """ client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) log.debug("%s connected", self._client_class_name) self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH]) @@ -95,14 +194,31 @@ class ControlSession: await self._writer.drain() async def _write_function_output(self, func: Callable, *args, **kwargs) -> None: + """ + Acts as a wrapper around a call to a specific task pool method. + + The method is called and any exception is caught and saved. If there is no output and no exception caught, a + generic confirmation message is sent back to the client. Otherwise the output or a string representation of + the exception caught is sent back. + + Args: + func: + Reference to the task pool method. + *args (optional): + Any positional arguments to call the method with. + *+kwargs (optional): + Any keyword-arguments to call the method with. + """ output = await return_or_exception(func, *args, **kwargs) self._writer.write(b"ok" if output is None else str(output).encode()) async def _cmd_name(self, **_kwargs) -> None: + """Maps to the `__str__` method of any task pool class.""" log.debug("%s requests task pool name", self._client_class_name) await self._write_function_output(self._pool.__class__.__str__, self._pool) async def _cmd_pool_size(self, **kwargs) -> None: + """Maps to the `pool_size` property of any task pool class.""" num = kwargs.get(NUM) if num is None: log.debug("%s requests pool size", self._client_class_name) @@ -112,45 +228,73 @@ class ControlSession: await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num) async def _cmd_num_running(self, **_kwargs) -> None: + """Maps to the `num_running` property of any task pool class.""" log.debug("%s requests number of running tasks", self._client_class_name) await self._write_function_output(self._pool.__class__.num_running.fget, self._pool) async def _cmd_start(self, **kwargs) -> None: + """Maps to the `start` method of the `SimpleTaskPool` class.""" num = kwargs[NUM] log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num)) await self._write_function_output(self._pool.start, num) async def _cmd_stop(self, **kwargs) -> None: + """Maps to the `stop` method of the `SimpleTaskPool` class.""" num = kwargs[NUM] log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num)) await self._write_function_output(self._pool.stop, num) async def _cmd_stop_all(self, **_kwargs) -> None: + """Maps to the `stop_all` method of the `SimpleTaskPool` class.""" log.debug("%s requests stopping all tasks", self._client_class_name) await self._write_function_output(self._pool.stop_all) async def _cmd_func_name(self, **_kwargs) -> None: + """Maps to the `func_name` method of the `SimpleTaskPool` class.""" log.debug("%s requests pool function name", self._client_class_name) await self._write_function_output(self._pool.__class__.func_name.fget, self._pool) - async def _execute_command(self, args: Namespace) -> None: - args = vars(args) - cmd: str = args.pop(CMD.CMD, None) - if cmd is not None: - method = getattr(self, f'_cmd_{cmd.replace("-", "_")}') - await method(**args) + async def _execute_command(self, **kwargs) -> None: + """ + Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it. + + Args: + **kwargs: + Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is + simply passed into the method determined from the command name. + """ + method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}') + await method(**kwargs) async def _parse_command(self, msg: str) -> None: + """ + Takes a message from the client and attempts to parse it. + + If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the + `CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire + dictionary of keyword-arguments returned by the `CommandParser` passed into it. + + Args: + msg: + The non-empty string read from the client stream. + """ try: - args = self._parser.parse_args(msg.split(' ')) + kwargs = vars(self._parser.parse_args(msg.split(' '))) except ArgumentError as e: self._writer.write(str(e).encode()) return except HelpRequested: return - await self._execute_command(args) + await self._execute_command(**kwargs) async def listen(self) -> None: + """ + Enters the main control loop that only ends if either the server or the client disconnect. + + Messages from the client are read and passed into the `_parse_command` method, which handles the rest. + This method should be called, when the client connection was established and the handshake was successful. + It will obviously block indefinitely. + """ while self._control_server.is_serving(): msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() if not msg: diff --git a/tests/test_server.py b/tests/test_server.py index 1e5030b..0d15e0b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,3 +1,24 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +Unittests for the `asyncio_taskpool.server` module. +""" + + import asyncio import logging from pathlib import Path diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..921e06d --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,324 @@ +__author__ = "Daniil Fajnberg" +__copyright__ = "Copyright © 2022 Daniil Fajnberg" +__license__ = """GNU LGPLv3.0 + +This file is part of asyncio-taskpool. + +asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of +version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. + +asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the GNU Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. +If not, see .""" + +__doc__ = """ +Unittests for the `asyncio_taskpool.session` module. +""" + + +import json +from argparse import ArgumentError, Namespace +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch, call + +from asyncio_taskpool import session +from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER +from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass +from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool + + +FOO, BAR = 'foo', 'bar' + + +class ControlServerTestCase(IsolatedAsyncioTestCase): + log_lvl: int + + @classmethod + def setUpClass(cls) -> None: + cls.log_lvl = session.log.level + session.log.setLevel(999) + + @classmethod + def tearDownClass(cls) -> None: + session.log.setLevel(cls.log_lvl) + + def setUp(self) -> None: + self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock())) + self.mock_client_class_name = FOO + BAR + self.mock_server = MagicMock(pool=self.mock_pool, + client_class_name=self.mock_client_class_name) + self.mock_reader = MagicMock() + self.mock_writer = MagicMock() + self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer) + + def test_init(self): + self.assertEqual(self.mock_server, self.session._control_server) + self.assertEqual(self.mock_pool, self.session._pool) + self.assertEqual(self.mock_client_class_name, self.session._client_class_name) + self.assertEqual(self.mock_reader, self.session._reader) + self.assertEqual(self.mock_writer, self.session._writer) + self.assertIsNone(self.session._parser) + self.assertIsNone(self.session._subparsers) + + def test__add_command(self): + expected_output = 123456 + mock_add_parser = MagicMock(return_value=expected_output) + self.session._subparsers = MagicMock(add_parser=mock_add_parser) + self.session._parser = MagicMock() + name, prog, short_help, long_help = 'abc', None, 'short123', None + kwargs = {'x': 1, 'y': 2} + output = self.session._add_command(name, prog, short_help, long_help, **kwargs) + self.assertEqual(expected_output, output) + mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help, + parent=self.session._parser, **kwargs) + + mock_add_parser.reset_mock() + + prog, long_help = 'ffffff', 'so long, wow' + output = self.session._add_command(name, prog, short_help, long_help, **kwargs) + self.assertEqual(expected_output, output) + mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help, + parent=self.session._parser, **kwargs) + + mock_add_parser.reset_mock() + + short_help = None + output = self.session._add_command(name, prog, short_help, long_help, **kwargs) + self.assertEqual(expected_output, output) + mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help, + parent=self.session._parser, **kwargs) + + @patch.object(session, 'get_first_doc_line') + @patch.object(session.ControlSession, '_add_command') + def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock): + self.assertIsNone(self.session._add_base_commands()) + mock__add_command.assert_called() + mock_get_first_doc_line.assert_called() + + mock__add_command.reset_mock() + mock_get_first_doc_line.reset_mock() + + self.assertIsNone(self.session._add_simple_commands()) + mock__add_command.assert_called() + mock_get_first_doc_line.assert_called() + + with self.assertRaises(NotImplementedError): + self.session._add_advanced_commands() + + @patch.object(session.ControlSession, '_add_simple_commands') + @patch.object(session.ControlSession, '_add_advanced_commands') + @patch.object(session.ControlSession, '_add_base_commands') + @patch.object(session, 'CommandParser') + def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock, + mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock): + mock_command_parser_cls.return_value = mock_parser = MagicMock() + self.session._pool = TaskPool() + width = 1234 + expected_parser_kwargs = { + 'prog': '', + SESSION_PARSER_WRITER: self.mock_writer, + CLIENT_INFO.TERMINAL_WIDTH: width, + } + self.assertIsNone(self.session._init_parser(width)) + mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) + mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) + mock__add_base_commands.assert_called_once_with() + mock__add_advanced_commands.assert_called_once_with() + mock__add_simple_commands.assert_not_called() + + mock_command_parser_cls.reset_mock() + mock_parser.add_subparsers.reset_mock() + mock__add_base_commands.reset_mock() + mock__add_advanced_commands.reset_mock() + mock__add_simple_commands.reset_mock() + + async def fake_coroutine(): pass + + self.session._pool = SimpleTaskPool(fake_coroutine) + self.assertIsNone(self.session._init_parser(width)) + mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) + mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) + mock__add_base_commands.assert_called_once_with() + mock__add_advanced_commands.assert_not_called() + mock__add_simple_commands.assert_called_once_with() + + mock_command_parser_cls.reset_mock() + mock_parser.add_subparsers.reset_mock() + mock__add_base_commands.reset_mock() + mock__add_advanced_commands.reset_mock() + mock__add_simple_commands.reset_mock() + + class FakeTaskPool(BaseTaskPool): + pass + + self.session._pool = FakeTaskPool() + with self.assertRaises(UnknownTaskPoolClass): + self.session._init_parser(width) + mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) + mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) + mock__add_base_commands.assert_called_once_with() + mock__add_advanced_commands.assert_not_called() + mock__add_simple_commands.assert_not_called() + + mock_command_parser_cls.reset_mock() + mock_parser.add_subparsers.reset_mock() + mock__add_base_commands.reset_mock() + mock__add_advanced_commands.reset_mock() + mock__add_simple_commands.reset_mock() + + self.session._pool = MagicMock() + with self.assertRaises(NotATaskPool): + self.session._init_parser(width) + mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) + mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) + mock__add_base_commands.assert_called_once_with() + mock__add_advanced_commands.assert_not_called() + mock__add_simple_commands.assert_not_called() + + @patch.object(session.ControlSession, '_init_parser') + async def test_client_handshake(self, mock__init_parser: MagicMock): + width = 5678 + msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' ' + mock_read = AsyncMock(return_value=msg.encode()) + self.mock_reader.read = mock_read + self.mock_writer.drain = AsyncMock() + self.assertIsNone(await self.session.client_handshake()) + mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) + mock__init_parser.assert_called_once_with(width) + self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode()) + self.mock_writer.drain.assert_awaited_once_with() + + @patch.object(session, 'return_or_exception') + async def test__write_function_output(self, mock_return_or_exception: MagicMock): + self.mock_writer.write = MagicMock() + mock_return_or_exception.return_value = None + func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'} + self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs)) + mock_return_or_exception.assert_called_once_with(func, *args, **kwargs) + self.mock_writer.write.assert_called_once_with(b"ok") + + mock_return_or_exception.reset_mock() + self.mock_writer.write.reset_mock() + + mock_return_or_exception.return_value = output = MagicMock() + self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs)) + mock_return_or_exception.assert_called_once_with(func, *args, **kwargs) + self.mock_writer.write.assert_called_once_with(str(output).encode()) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_name(self, mock__write_function_output: AsyncMock): + self.assertIsNone(await self.session._cmd_name()) + mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock): + num = 12345 + kwargs = {session.NUM: num, FOO: BAR} + self.assertIsNone(await self.session._cmd_pool_size(**kwargs)) + mock__write_function_output.assert_awaited_once_with( + self.mock_pool.__class__.pool_size.fset, self.session._pool, num + ) + + mock__write_function_output.reset_mock() + + kwargs.pop(session.NUM) + self.assertIsNone(await self.session._cmd_pool_size(**kwargs)) + mock__write_function_output.assert_awaited_once_with( + self.mock_pool.__class__.pool_size.fget, self.session._pool + ) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_num_running(self, mock__write_function_output: AsyncMock): + self.assertIsNone(await self.session._cmd_num_running()) + mock__write_function_output.assert_awaited_once_with( + self.mock_pool.__class__.num_running.fget, self.session._pool + ) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_start(self, mock__write_function_output: AsyncMock): + num = 12345 + kwargs = {session.NUM: num, FOO: BAR} + self.assertIsNone(await self.session._cmd_start(**kwargs)) + mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_stop(self, mock__write_function_output: AsyncMock): + num = 12345 + kwargs = {session.NUM: num, FOO: BAR} + self.assertIsNone(await self.session._cmd_stop(**kwargs)) + mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock): + self.assertIsNone(await self.session._cmd_stop_all()) + mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all) + + @patch.object(session.ControlSession, '_write_function_output') + async def test__cmd_func_name(self, mock__write_function_output: AsyncMock): + self.assertIsNone(await self.session._cmd_func_name()) + mock__write_function_output.assert_awaited_once_with( + self.mock_pool.__class__.func_name.fget, self.session._pool + ) + + async def test__execute_command(self): + mock_method = AsyncMock() + cmd = 'this-is-a-test' + setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method) + kwargs = {FOO: BAR, 'hello': 'python'} + self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs)) + mock_method.assert_awaited_once_with(**kwargs) + + @patch.object(session.ControlSession, '_execute_command') + async def test__parse_command(self, mock__execute_command: AsyncMock): + msg = 'asdf asd as a' + kwargs = {FOO: BAR, 'hello': 'python'} + mock_parse_args = MagicMock(return_value=Namespace(**kwargs)) + self.session._parser = MagicMock(parse_args=mock_parse_args) + self.mock_writer.write = MagicMock() + self.assertIsNone(await self.session._parse_command(msg)) + mock_parse_args.assert_called_once_with(msg.split(' ')) + self.mock_writer.write.assert_not_called() + mock__execute_command.assert_awaited_once_with(**kwargs) + + mock__execute_command.reset_mock() + mock_parse_args.reset_mock() + + mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") + self.assertIsNone(await self.session._parse_command(msg)) + mock_parse_args.assert_called_once_with(msg.split(' ')) + self.mock_writer.write.assert_called_once_with(str(exc).encode()) + mock__execute_command.assert_not_awaited() + + self.mock_writer.write.reset_mock() + mock_parse_args.reset_mock() + + mock_parse_args.side_effect = HelpRequested() + self.assertIsNone(await self.session._parse_command(msg)) + mock_parse_args.assert_called_once_with(msg.split(' ')) + self.mock_writer.write.assert_not_called() + mock__execute_command.assert_not_awaited() + + @patch.object(session.ControlSession, '_parse_command') + async def test_listen(self, mock__parse_command: AsyncMock): + def make_reader_return_empty(): + self.mock_reader.read.return_value = b'' + self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) + msg = "fascinating" + self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) + self.assertIsNone(await self.session.listen()) + self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) + mock__parse_command.assert_awaited_once_with(msg) + self.mock_writer.drain.assert_awaited_once_with() + + self.mock_reader.read.reset_mock() + mock__parse_command.reset_mock() + self.mock_writer.drain.reset_mock() + + self.mock_server.is_serving = MagicMock(return_value=False) + self.assertIsNone(await self.session.listen()) + self.mock_reader.read.assert_not_awaited() + mock__parse_command.assert_not_awaited() + self.mock_writer.drain.assert_not_awaited()