diff --git a/.coveragerc b/.coveragerc
index 11a7f12..f70c551 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -8,5 +8,8 @@ omit =
fail_under = 100
show_missing = True
skip_covered = False
+exclude_lines =
+ if TYPE_CHECKING:
+ if __name__ == ['"]__main__['"]:
omit =
tests/*
diff --git a/setup.cfg b/setup.cfg
index 5c2e218..c0caba2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
-version = 0.3.2
+version = 0.3.3
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks
diff --git a/src/asyncio_taskpool/session.py b/src/asyncio_taskpool/session.py
index b35a0f8..6db1d01 100644
--- a/src/asyncio_taskpool/session.py
+++ b/src/asyncio_taskpool/session.py
@@ -1,6 +1,27 @@
+__author__ = "Daniil Fajnberg"
+__copyright__ = "Copyright © 2022 Daniil Fajnberg"
+__license__ = """GNU LGPLv3.0
+
+This file is part of asyncio-taskpool.
+
+asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
+version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
+
+asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the GNU Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
+If not, see ."""
+
+__doc__ = """
+This module contains the the definition of the control session class used by the control server.
+"""
+
+
import logging
import json
-from argparse import ArgumentError, HelpFormatter, Namespace
+from argparse import ArgumentError, HelpFormatter
from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING
@@ -18,7 +39,29 @@ log = logging.getLogger(__name__)
class ControlSession:
+ """
+ This class defines the API for controlling a task pool instance from the outside.
+
+ The commands received from a connected client are translated into method calls on the task pool instance.
+ A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream.
+ """
+
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
+ """
+ Instantiation should happen once a client connection to the control server has already been established.
+
+ For more convenient/efficient access, some of the server's properties are saved in separate attributes.
+ The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
+ initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured.
+
+ Args:
+ server:
+ The instance of a `ControlServer` subclass starting the session.
+ reader:
+ The `asyncio.StreamReader` created when a client connected to the server.
+ writer:
+ The `asyncio.StreamWriter` created when a client connected to the server.
+ """
self._control_server: 'ControlServer' = server
self._pool: Union[TaskPool, SimpleTaskPool] = server.pool
self._client_class_name = server.client_class_name
@@ -27,8 +70,31 @@ class ControlSession:
self._parser: Optional[CommandParser] = None
self._subparsers = None
- def _add_parser_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
- **kwargs) -> CommandParser:
+ def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None,
+ **kwargs) -> CommandParser:
+ """
+ Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance.
+
+ Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument.
+
+ Args:
+ name:
+ The command name; passed directly into the `add_parser` method.
+ prog (optional):
+ Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set
+ equal to the `name` argument.
+ short_help (optional):
+ Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the
+ `long_help` argument is present; in that case the `long_help` argument is passed as `help`.
+ long_help (optional):
+ Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and
+ the `short_help` argument is present; in that case the `short_help` argument is passed as `description`.
+ **kwargs (optional):
+ Any keyword-arguments to directly pass into the `add_parser` method.
+
+ Returns:
+ An instance of the `CommandParser` class representing the newly added control command.
+ """
if prog is None:
prog = name
kwargs.setdefault('help', short_help or long_help)
@@ -36,9 +102,16 @@ class ControlSession:
return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs)
def _add_base_commands(self) -> None:
- self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
- self._add_parser_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
- self._add_parser_command(
+ """
+ Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled.
+
+ These include commands mapping to the following pool methods:
+ - __str__
+ - pool_size (get/set property)
+ - num_running
+ """
+ self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
+ self._add_command(
CMD.POOL_SIZE,
short_help="Get/set the maximum number of tasks in the pool.",
formatter_class=HelpFormatter
@@ -47,36 +120,57 @@ class ControlSession:
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
)
- self._add_parser_command(
- CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)
- )
+ self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
def _add_simple_commands(self) -> None:
- self._add_parser_command(
+ """
+ Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled.
+
+ These include commands mapping to the following pool methods:
+ - start
+ - stop
+ - stop_all
+ - func_name
+ """
+ self._add_command(
CMD.START, short_help=get_first_doc_line(self._pool.__class__.start)
).add_optional_num_argument(
help="Number of tasks to start."
)
- self._add_parser_command(
+ self._add_command(
CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop)
).add_optional_num_argument(
help="Number of tasks to stop."
)
- self._add_parser_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
- self._add_parser_command(
- CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)
- )
+ self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all))
+ self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget))
def _add_advanced_commands(self) -> None:
+ """
+ Adds the commands that are only supported, if a `TaskPool` object is controlled.
+
+ These include commands mapping to the following pool methods:
+ - ...
+ """
raise NotImplementedError
def _init_parser(self, client_terminal_width: int) -> None:
+ """
+ Initializes and fully configures the `CommandParser` responsible for handling the input.
+
+ Depending on what specific task pool class is controlled by the server, different commands are added.
+
+ Args:
+ client_terminal_width:
+ The number of columns of the client's terminal to be able to nicely format messages from the parser.
+ """
parser_kwargs = {
'prog': '',
SESSION_PARSER_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width,
}
self._parser = CommandParser(**parser_kwargs)
+ self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD)
self._add_base_commands()
if isinstance(self._pool, TaskPool):
self._add_advanced_commands()
@@ -88,6 +182,11 @@ class ControlSession:
raise NotATaskPool(f"Not a task pool instance: {self._pool}")
async def client_handshake(self) -> None:
+ """
+ This method must be invoked before starting any other client interaction.
+
+ Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured.
+ """
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name)
self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH])
@@ -95,14 +194,31 @@ class ControlSession:
await self._writer.drain()
async def _write_function_output(self, func: Callable, *args, **kwargs) -> None:
+ """
+ Acts as a wrapper around a call to a specific task pool method.
+
+ The method is called and any exception is caught and saved. If there is no output and no exception caught, a
+ generic confirmation message is sent back to the client. Otherwise the output or a string representation of
+ the exception caught is sent back.
+
+ Args:
+ func:
+ Reference to the task pool method.
+ *args (optional):
+ Any positional arguments to call the method with.
+ *+kwargs (optional):
+ Any keyword-arguments to call the method with.
+ """
output = await return_or_exception(func, *args, **kwargs)
self._writer.write(b"ok" if output is None else str(output).encode())
async def _cmd_name(self, **_kwargs) -> None:
+ """Maps to the `__str__` method of any task pool class."""
log.debug("%s requests task pool name", self._client_class_name)
await self._write_function_output(self._pool.__class__.__str__, self._pool)
async def _cmd_pool_size(self, **kwargs) -> None:
+ """Maps to the `pool_size` property of any task pool class."""
num = kwargs.get(NUM)
if num is None:
log.debug("%s requests pool size", self._client_class_name)
@@ -112,45 +228,73 @@ class ControlSession:
await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num)
async def _cmd_num_running(self, **_kwargs) -> None:
+ """Maps to the `num_running` property of any task pool class."""
log.debug("%s requests number of running tasks", self._client_class_name)
await self._write_function_output(self._pool.__class__.num_running.fget, self._pool)
async def _cmd_start(self, **kwargs) -> None:
+ """Maps to the `start` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.start, num)
async def _cmd_stop(self, **kwargs) -> None:
+ """Maps to the `stop` method of the `SimpleTaskPool` class."""
num = kwargs[NUM]
log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num))
await self._write_function_output(self._pool.stop, num)
async def _cmd_stop_all(self, **_kwargs) -> None:
+ """Maps to the `stop_all` method of the `SimpleTaskPool` class."""
log.debug("%s requests stopping all tasks", self._client_class_name)
await self._write_function_output(self._pool.stop_all)
async def _cmd_func_name(self, **_kwargs) -> None:
+ """Maps to the `func_name` method of the `SimpleTaskPool` class."""
log.debug("%s requests pool function name", self._client_class_name)
await self._write_function_output(self._pool.__class__.func_name.fget, self._pool)
- async def _execute_command(self, args: Namespace) -> None:
- args = vars(args)
- cmd: str = args.pop(CMD.CMD, None)
- if cmd is not None:
- method = getattr(self, f'_cmd_{cmd.replace("-", "_")}')
- await method(**args)
+ async def _execute_command(self, **kwargs) -> None:
+ """
+ Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it.
+
+ Args:
+ **kwargs:
+ Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is
+ simply passed into the method determined from the command name.
+ """
+ method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}')
+ await method(**kwargs)
async def _parse_command(self, msg: str) -> None:
+ """
+ Takes a message from the client and attempts to parse it.
+
+ If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the
+ `CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire
+ dictionary of keyword-arguments returned by the `CommandParser` passed into it.
+
+ Args:
+ msg:
+ The non-empty string read from the client stream.
+ """
try:
- args = self._parser.parse_args(msg.split(' '))
+ kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e:
self._writer.write(str(e).encode())
return
except HelpRequested:
return
- await self._execute_command(args)
+ await self._execute_command(**kwargs)
async def listen(self) -> None:
+ """
+ Enters the main control loop that only ends if either the server or the client disconnect.
+
+ Messages from the client are read and passed into the `_parse_command` method, which handles the rest.
+ This method should be called, when the client connection was established and the handshake was successful.
+ It will obviously block indefinitely.
+ """
while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
if not msg:
diff --git a/tests/test_server.py b/tests/test_server.py
index 1e5030b..0d15e0b 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,3 +1,24 @@
+__author__ = "Daniil Fajnberg"
+__copyright__ = "Copyright © 2022 Daniil Fajnberg"
+__license__ = """GNU LGPLv3.0
+
+This file is part of asyncio-taskpool.
+
+asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
+version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
+
+asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the GNU Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
+If not, see ."""
+
+__doc__ = """
+Unittests for the `asyncio_taskpool.server` module.
+"""
+
+
import asyncio
import logging
from pathlib import Path
diff --git a/tests/test_session.py b/tests/test_session.py
new file mode 100644
index 0000000..921e06d
--- /dev/null
+++ b/tests/test_session.py
@@ -0,0 +1,324 @@
+__author__ = "Daniil Fajnberg"
+__copyright__ = "Copyright © 2022 Daniil Fajnberg"
+__license__ = """GNU LGPLv3.0
+
+This file is part of asyncio-taskpool.
+
+asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
+version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
+
+asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the GNU Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
+If not, see ."""
+
+__doc__ = """
+Unittests for the `asyncio_taskpool.session` module.
+"""
+
+
+import json
+from argparse import ArgumentError, Namespace
+from unittest import IsolatedAsyncioTestCase
+from unittest.mock import AsyncMock, MagicMock, patch, call
+
+from asyncio_taskpool import session
+from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_PARSER_WRITER
+from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
+from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool
+
+
+FOO, BAR = 'foo', 'bar'
+
+
+class ControlServerTestCase(IsolatedAsyncioTestCase):
+ log_lvl: int
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls.log_lvl = session.log.level
+ session.log.setLevel(999)
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ session.log.setLevel(cls.log_lvl)
+
+ def setUp(self) -> None:
+ self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock()))
+ self.mock_client_class_name = FOO + BAR
+ self.mock_server = MagicMock(pool=self.mock_pool,
+ client_class_name=self.mock_client_class_name)
+ self.mock_reader = MagicMock()
+ self.mock_writer = MagicMock()
+ self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer)
+
+ def test_init(self):
+ self.assertEqual(self.mock_server, self.session._control_server)
+ self.assertEqual(self.mock_pool, self.session._pool)
+ self.assertEqual(self.mock_client_class_name, self.session._client_class_name)
+ self.assertEqual(self.mock_reader, self.session._reader)
+ self.assertEqual(self.mock_writer, self.session._writer)
+ self.assertIsNone(self.session._parser)
+ self.assertIsNone(self.session._subparsers)
+
+ def test__add_command(self):
+ expected_output = 123456
+ mock_add_parser = MagicMock(return_value=expected_output)
+ self.session._subparsers = MagicMock(add_parser=mock_add_parser)
+ self.session._parser = MagicMock()
+ name, prog, short_help, long_help = 'abc', None, 'short123', None
+ kwargs = {'x': 1, 'y': 2}
+ output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
+ self.assertEqual(expected_output, output)
+ mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help,
+ parent=self.session._parser, **kwargs)
+
+ mock_add_parser.reset_mock()
+
+ prog, long_help = 'ffffff', 'so long, wow'
+ output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
+ self.assertEqual(expected_output, output)
+ mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help,
+ parent=self.session._parser, **kwargs)
+
+ mock_add_parser.reset_mock()
+
+ short_help = None
+ output = self.session._add_command(name, prog, short_help, long_help, **kwargs)
+ self.assertEqual(expected_output, output)
+ mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help,
+ parent=self.session._parser, **kwargs)
+
+ @patch.object(session, 'get_first_doc_line')
+ @patch.object(session.ControlSession, '_add_command')
+ def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock):
+ self.assertIsNone(self.session._add_base_commands())
+ mock__add_command.assert_called()
+ mock_get_first_doc_line.assert_called()
+
+ mock__add_command.reset_mock()
+ mock_get_first_doc_line.reset_mock()
+
+ self.assertIsNone(self.session._add_simple_commands())
+ mock__add_command.assert_called()
+ mock_get_first_doc_line.assert_called()
+
+ with self.assertRaises(NotImplementedError):
+ self.session._add_advanced_commands()
+
+ @patch.object(session.ControlSession, '_add_simple_commands')
+ @patch.object(session.ControlSession, '_add_advanced_commands')
+ @patch.object(session.ControlSession, '_add_base_commands')
+ @patch.object(session, 'CommandParser')
+ def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock,
+ mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock):
+ mock_command_parser_cls.return_value = mock_parser = MagicMock()
+ self.session._pool = TaskPool()
+ width = 1234
+ expected_parser_kwargs = {
+ 'prog': '',
+ SESSION_PARSER_WRITER: self.mock_writer,
+ CLIENT_INFO.TERMINAL_WIDTH: width,
+ }
+ self.assertIsNone(self.session._init_parser(width))
+ mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
+ mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
+ mock__add_base_commands.assert_called_once_with()
+ mock__add_advanced_commands.assert_called_once_with()
+ mock__add_simple_commands.assert_not_called()
+
+ mock_command_parser_cls.reset_mock()
+ mock_parser.add_subparsers.reset_mock()
+ mock__add_base_commands.reset_mock()
+ mock__add_advanced_commands.reset_mock()
+ mock__add_simple_commands.reset_mock()
+
+ async def fake_coroutine(): pass
+
+ self.session._pool = SimpleTaskPool(fake_coroutine)
+ self.assertIsNone(self.session._init_parser(width))
+ mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
+ mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
+ mock__add_base_commands.assert_called_once_with()
+ mock__add_advanced_commands.assert_not_called()
+ mock__add_simple_commands.assert_called_once_with()
+
+ mock_command_parser_cls.reset_mock()
+ mock_parser.add_subparsers.reset_mock()
+ mock__add_base_commands.reset_mock()
+ mock__add_advanced_commands.reset_mock()
+ mock__add_simple_commands.reset_mock()
+
+ class FakeTaskPool(BaseTaskPool):
+ pass
+
+ self.session._pool = FakeTaskPool()
+ with self.assertRaises(UnknownTaskPoolClass):
+ self.session._init_parser(width)
+ mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
+ mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
+ mock__add_base_commands.assert_called_once_with()
+ mock__add_advanced_commands.assert_not_called()
+ mock__add_simple_commands.assert_not_called()
+
+ mock_command_parser_cls.reset_mock()
+ mock_parser.add_subparsers.reset_mock()
+ mock__add_base_commands.reset_mock()
+ mock__add_advanced_commands.reset_mock()
+ mock__add_simple_commands.reset_mock()
+
+ self.session._pool = MagicMock()
+ with self.assertRaises(NotATaskPool):
+ self.session._init_parser(width)
+ mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs)
+ mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD)
+ mock__add_base_commands.assert_called_once_with()
+ mock__add_advanced_commands.assert_not_called()
+ mock__add_simple_commands.assert_not_called()
+
+ @patch.object(session.ControlSession, '_init_parser')
+ async def test_client_handshake(self, mock__init_parser: MagicMock):
+ width = 5678
+ msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
+ mock_read = AsyncMock(return_value=msg.encode())
+ self.mock_reader.read = mock_read
+ self.mock_writer.drain = AsyncMock()
+ self.assertIsNone(await self.session.client_handshake())
+ mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
+ mock__init_parser.assert_called_once_with(width)
+ self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
+ self.mock_writer.drain.assert_awaited_once_with()
+
+ @patch.object(session, 'return_or_exception')
+ async def test__write_function_output(self, mock_return_or_exception: MagicMock):
+ self.mock_writer.write = MagicMock()
+ mock_return_or_exception.return_value = None
+ func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'}
+ self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
+ mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
+ self.mock_writer.write.assert_called_once_with(b"ok")
+
+ mock_return_or_exception.reset_mock()
+ self.mock_writer.write.reset_mock()
+
+ mock_return_or_exception.return_value = output = MagicMock()
+ self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs))
+ mock_return_or_exception.assert_called_once_with(func, *args, **kwargs)
+ self.mock_writer.write.assert_called_once_with(str(output).encode())
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_name(self, mock__write_function_output: AsyncMock):
+ self.assertIsNone(await self.session._cmd_name())
+ mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool)
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock):
+ num = 12345
+ kwargs = {session.NUM: num, FOO: BAR}
+ self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
+ mock__write_function_output.assert_awaited_once_with(
+ self.mock_pool.__class__.pool_size.fset, self.session._pool, num
+ )
+
+ mock__write_function_output.reset_mock()
+
+ kwargs.pop(session.NUM)
+ self.assertIsNone(await self.session._cmd_pool_size(**kwargs))
+ mock__write_function_output.assert_awaited_once_with(
+ self.mock_pool.__class__.pool_size.fget, self.session._pool
+ )
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_num_running(self, mock__write_function_output: AsyncMock):
+ self.assertIsNone(await self.session._cmd_num_running())
+ mock__write_function_output.assert_awaited_once_with(
+ self.mock_pool.__class__.num_running.fget, self.session._pool
+ )
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_start(self, mock__write_function_output: AsyncMock):
+ num = 12345
+ kwargs = {session.NUM: num, FOO: BAR}
+ self.assertIsNone(await self.session._cmd_start(**kwargs))
+ mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num)
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_stop(self, mock__write_function_output: AsyncMock):
+ num = 12345
+ kwargs = {session.NUM: num, FOO: BAR}
+ self.assertIsNone(await self.session._cmd_stop(**kwargs))
+ mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num)
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock):
+ self.assertIsNone(await self.session._cmd_stop_all())
+ mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all)
+
+ @patch.object(session.ControlSession, '_write_function_output')
+ async def test__cmd_func_name(self, mock__write_function_output: AsyncMock):
+ self.assertIsNone(await self.session._cmd_func_name())
+ mock__write_function_output.assert_awaited_once_with(
+ self.mock_pool.__class__.func_name.fget, self.session._pool
+ )
+
+ async def test__execute_command(self):
+ mock_method = AsyncMock()
+ cmd = 'this-is-a-test'
+ setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method)
+ kwargs = {FOO: BAR, 'hello': 'python'}
+ self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs))
+ mock_method.assert_awaited_once_with(**kwargs)
+
+ @patch.object(session.ControlSession, '_execute_command')
+ async def test__parse_command(self, mock__execute_command: AsyncMock):
+ msg = 'asdf asd as a'
+ kwargs = {FOO: BAR, 'hello': 'python'}
+ mock_parse_args = MagicMock(return_value=Namespace(**kwargs))
+ self.session._parser = MagicMock(parse_args=mock_parse_args)
+ self.mock_writer.write = MagicMock()
+ self.assertIsNone(await self.session._parse_command(msg))
+ mock_parse_args.assert_called_once_with(msg.split(' '))
+ self.mock_writer.write.assert_not_called()
+ mock__execute_command.assert_awaited_once_with(**kwargs)
+
+ mock__execute_command.reset_mock()
+ mock_parse_args.reset_mock()
+
+ mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
+ self.assertIsNone(await self.session._parse_command(msg))
+ mock_parse_args.assert_called_once_with(msg.split(' '))
+ self.mock_writer.write.assert_called_once_with(str(exc).encode())
+ mock__execute_command.assert_not_awaited()
+
+ self.mock_writer.write.reset_mock()
+ mock_parse_args.reset_mock()
+
+ mock_parse_args.side_effect = HelpRequested()
+ self.assertIsNone(await self.session._parse_command(msg))
+ mock_parse_args.assert_called_once_with(msg.split(' '))
+ self.mock_writer.write.assert_not_called()
+ mock__execute_command.assert_not_awaited()
+
+ @patch.object(session.ControlSession, '_parse_command')
+ async def test_listen(self, mock__parse_command: AsyncMock):
+ def make_reader_return_empty():
+ self.mock_reader.read.return_value = b''
+ self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
+ msg = "fascinating"
+ self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
+ self.assertIsNone(await self.session.listen())
+ self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
+ mock__parse_command.assert_awaited_once_with(msg)
+ self.mock_writer.drain.assert_awaited_once_with()
+
+ self.mock_reader.read.reset_mock()
+ mock__parse_command.reset_mock()
+ self.mock_writer.drain.reset_mock()
+
+ self.mock_server.is_serving = MagicMock(return_value=False)
+ self.assertIsNone(await self.session.listen())
+ self.mock_reader.read.assert_not_awaited()
+ mock__parse_command.assert_not_awaited()
+ self.mock_writer.drain.assert_not_awaited()