Compare commits

...

2 Commits

20 changed files with 139 additions and 83 deletions

View File

@ -1,8 +1,7 @@
[run] [run]
source = src/ source = src/
branch = true branch = true
omit = command_line = -m unittest discover
.venv/*
[report] [report]
fail_under = 100 fail_under = 100
@ -11,5 +10,4 @@ skip_covered = False
exclude_lines = exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
if __name__ == ['"]__main__['"]: if __name__ == ['"]__main__['"]:
omit = if sys.version_info.+:
tests/*

View File

@ -7,6 +7,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- '3.8'
- '3.9' - '3.9'
- '3.10' - '3.10'
steps: steps:

View File

@ -2,7 +2,7 @@ version: 2
build: build:
os: 'ubuntu-20.04' os: 'ubuntu-20.04'
tools: tools:
python: '3.9' python: '3.8'
python: python:
install: install:
- method: pip - method: pip

View File

@ -74,7 +74,7 @@ pip install asyncio-taskpool
## Dependencies ## Dependencies
Python Version 3.9+, tested on Linux Python Version 3.8+, tested on Linux
## Testing ## Testing

View File

@ -12,13 +12,14 @@
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. # You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>. # If not, see <https://www.gnu.org/licenses/>.
coverage erase && coverage run -m unittest discover 2> /dev/null coverage erase
coverage run 2> /dev/null
typeset total typeset report=$(coverage report)
total=$(coverage report | awk '$1 == "TOTAL" {print $NF}') typeset total=$(echo "${report}" | awk '$1 == "TOTAL" {print $NF; exit}')
if [[ $total == 100% ]]; then if [[ ${total} == 100% ]]; then
echo $total echo ${total}
else else
coverage report echo "${report}"
fi fi

View File

@ -1,7 +0,0 @@
asyncio\_taskpool.control.parser module
=======================================
.. automodule:: asyncio_taskpool.control.parser
:members:
:undoc-members:
:show-inheritance:

View File

@ -13,6 +13,4 @@ Submodules
:maxdepth: 4 :maxdepth: 4
asyncio_taskpool.control.client asyncio_taskpool.control.client
asyncio_taskpool.control.parser
asyncio_taskpool.control.server asyncio_taskpool.control.server
asyncio_taskpool.control.session

View File

@ -1,7 +0,0 @@
asyncio\_taskpool.control.session module
========================================
.. automodule:: asyncio_taskpool.control.session
:members:
:undoc-members:
:show-inheritance:

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 1.0.0 version = 1.0.1
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks
@ -25,7 +25,7 @@ classifiers =
package_dir = package_dir =
= src = src
packages = find: packages = find:
python_requires = >=3.9 python_requires = >=3.8
[options.extras_require] [options.extras_require]
dev = dev =

View File

@ -97,21 +97,22 @@ class ControlClient(ABC):
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns: Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect; `None`, if either `Ctrl+C` was hit, an empty or whitespace-only string was entered, or the user wants the
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase. client to disconnect; otherwise, returns the user's input, stripped of leading and trailing spaces and
converted to lowercase.
""" """
try: try:
msg = input("> ").strip().lower() cmd = input("> ").strip().lower()
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command. except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
msg = CLIENT_EXIT cmd = CLIENT_EXIT
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt. except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print() print()
return return
if msg == CLIENT_EXIT: if cmd == CLIENT_EXIT:
writer.close() writer.close()
self._connected = False self._connected = False
return return
return msg return cmd or None # will be None if `cmd` is an empty string
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
""" """

View File

@ -17,19 +17,21 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
Definition of the :class:`ControlParser` used in a Definition of the :class:`ControlParser` used in a
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`. :class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
It should not be considered part of the public API.
""" """
import logging import logging
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
from ast import literal_eval from ast import literal_eval
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature from inspect import Parameter, getmembers, isfunction, signature
from io import StringIO
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
from ..exceptions import HelpRequested, ParserError from ..exceptions import HelpRequested, ParserError
from ..internals.constants import CLIENT_INFO, CMD, STREAM_WRITER from ..internals.constants import CLIENT_INFO, CMD
from ..internals.helpers import get_first_doc_line, resolve_dotted_path from ..internals.helpers import get_first_doc_line, resolve_dotted_path
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
@ -52,8 +54,8 @@ class ControlParser(ArgumentParser):
""" """
Subclass of the standard :code:`argparse.ArgumentParser` for pool control. Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a file-like
instance passed to it during initialization. `StringIO` instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client. connected client.
Finally, it offers some convenience methods and makes use of custom exceptions. Finally, it offers some convenience methods and makes use of custom exceptions.
@ -87,25 +89,23 @@ class ControlParser(ArgumentParser):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
return ClientHelpFormatter return ClientHelpFormatter
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None: def __init__(self, stream: StringIO, terminal_width: int = None, **kwargs) -> None:
""" """
Sets some internal attributes in addition to the base class. Sets some internal attributes in addition to the base class.
Args: Args:
stream_writer: stream:
The instance of the :class:`asyncio.StreamWriter` to use for message output. A file-like I/O object to use for message output.
terminal_width (optional): terminal_width (optional):
The terminal width to use for all message formatting. By default the :code:`columns` attribute from The terminal width to use for all message formatting. By default the :code:`columns` attribute from
:func:`shutil.get_terminal_size` is taken. :func:`shutil.get_terminal_size` is taken.
**kwargs(optional): **kwargs(optional):
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`. class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
""" """
self._stream_writer: StreamWriter = stream_writer self._stream: StringIO = stream
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class')) kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs) super().__init__(**kwargs)
self._flags: Set[str] = set() self._flags: Set[str] = set()
self._commands = None self._commands = None
@ -194,7 +194,7 @@ class ControlParser(ArgumentParser):
Dictionary mapping class member names to the (sub-)parsers created from them. Dictionary mapping class member names to the (sub-)parsers created from them.
""" """
parsers: ParsersDict = {} parsers: ParsersDict = {}
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width} common_kwargs = {'stream': self._stream, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
for name, member in getmembers(cls): for name, member in getmembers(cls):
if name in omit_members or (name.startswith('_') and public_only): if name in omit_members or (name.startswith('_') and public_only):
continue continue
@ -214,9 +214,9 @@ class ControlParser(ArgumentParser):
return self._commands return self._commands
def _print_message(self, message: str, *args, **kwargs) -> None: def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream buffer."""
if message: if message:
self._stream_writer.write(message.encode()) self._stream.write(message)
def exit(self, status: int = 0, message: str = None) -> None: def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked.""" """This is overridden to prevent system exit to be invoked."""

View File

@ -31,6 +31,7 @@ from typing import Optional, Union
from .client import ControlClient, TCPControlClient, UnixControlClient from .client import ControlClient, TCPControlClient, UnixControlClient
from .session import ControlSession from .session import ControlSession
from ..pool import AnyTaskPoolT from ..pool import AnyTaskPoolT
from ..internals.helpers import classmethod
from ..internals.types import ConnectedCallbackT, PathT from ..internals.types import ConnectedCallbackT, PathT

View File

@ -16,6 +16,8 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
Definition of the :class:`ControlSession` used by a :class:`ControlServer`. Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
It should not be considered part of the public API.
""" """
@ -24,12 +26,13 @@ import json
from argparse import ArgumentError from argparse import ArgumentError
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from inspect import isfunction, signature from inspect import isfunction, signature
from io import StringIO
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from .parser import ControlParser from .parser import ControlParser
from ..exceptions import CommandError, HelpRequested, ParserError from ..exceptions import CommandError, HelpRequested, ParserError
from ..pool import TaskPool, SimpleTaskPool from ..pool import TaskPool, SimpleTaskPool
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES
from ..internals.helpers import return_or_exception from ..internals.helpers import return_or_exception
if TYPE_CHECKING: if TYPE_CHECKING:
@ -72,6 +75,7 @@ class ControlSession:
self._reader: StreamReader = reader self._reader: StreamReader = reader
self._writer: StreamWriter = writer self._writer: StreamWriter = writer
self._parser: Optional[ControlParser] = None self._parser: Optional[ControlParser] = None
self._response_buffer: StringIO = StringIO()
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None: async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
""" """
@ -133,7 +137,7 @@ class ControlSession:
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
parser_kwargs = { parser_kwargs = {
STREAM_WRITER: self._writer, 'stream': self._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH], CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '', 'prog': '',
'usage': f'[-h] [{CMD}] ...' 'usage': f'[-h] [{CMD}] ...'
@ -160,7 +164,7 @@ class ControlSession:
kwargs = vars(self._parser.parse_args(msg.split(' '))) kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e: except ArgumentError as e:
log.debug("%s got an ArgumentError", self._client_class_name) log.debug("%s got an ArgumentError", self._client_class_name)
self._writer.write(str(e).encode()) self._response_buffer.write(str(e))
return return
except (HelpRequested, ParserError): except (HelpRequested, ParserError):
log.debug("%s received usage help", self._client_class_name) log.debug("%s received usage help", self._client_class_name)
@ -171,7 +175,7 @@ class ControlSession:
elif isinstance(command, property): elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs) await self._exec_property_and_respond(command, **kwargs)
else: else:
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode()) self._response_buffer.write(str(CommandError(f"Unknown command object: {command}")))
async def listen(self) -> None: async def listen(self) -> None:
""" """
@ -188,4 +192,8 @@ class ControlSession:
log.debug("%s disconnected", self._client_class_name) log.debug("%s disconnected", self._client_class_name)
break break
await self._parse_command(msg) await self._parse_command(msg)
response = self._response_buffer.getvalue()
self._response_buffer.seek(0)
self._response_buffer.truncate()
self._writer.write(response.encode())
await self._writer.drain() await self._writer.drain()

View File

@ -27,7 +27,6 @@ DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100
STREAM_WRITER = 'stream_writer'
CMD = 'command' CMD = 'command'
CMD_OK = b"ok" CMD_OK = b"ok"

View File

@ -19,10 +19,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
""" """
import builtins
import sys
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from importlib import import_module from importlib import import_module
from inspect import getdoc from inspect import getdoc
from typing import Any, Optional, Union from typing import Any, Callable, Optional, Type, Union
from .types import T, AnyCallableT, ArgsT, KwArgsT from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -131,3 +133,25 @@ def resolve_dotted_path(dotted_path: str) -> object:
import_module(module_name) import_module(module_name)
found = getattr(found, name) found = getattr(found, name)
return found return found
class ClassMethodWorkaround:
"""Dirty workaround to make the `@classmethod` decorator work with properties."""
def __init__(self, method_or_property: Union[Callable, property]) -> None:
if isinstance(method_or_property, property):
self._getter = method_or_property.fget
else:
self._getter = method_or_property
def __get__(self, obj: Union[T, None], cls: Union[Type[T], None]) -> Any:
if obj is None:
return self._getter(cls)
return self._getter(obj)
# Starting with Python 3.9, this is thankfully no longer necessary.
if sys.version_info[:2] < (3, 9):
classmethod = ClassMethodWorkaround
else:
classmethod = builtins.classmethod

View File

@ -38,7 +38,7 @@ class CLITestCase(IsolatedAsyncioTestCase):
mock_client = MagicMock(start=mock_client_start) mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client) mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789} mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls, **mock_client_kwargs}
self.assertIsNone(await module.main()) self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with() mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs) mock_client_cls.assert_called_once_with(**mock_client_kwargs)

View File

@ -41,9 +41,9 @@ class ControlParserTestCase(TestCase):
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory') self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.stream_writer, self.terminal_width = MagicMock(), 420 self.stream, self.terminal_width = MagicMock(), 420
self.kwargs = { self.kwargs = {
'stream_writer': self.stream_writer, 'stream': self.stream,
'terminal_width': self.terminal_width, 'terminal_width': self.terminal_width,
'formatter_class': FOO 'formatter_class': FOO
} }
@ -72,10 +72,9 @@ class ControlParserTestCase(TestCase):
def test_init(self): def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser) self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.stream_writer, self.parser._stream_writer) self.assertEqual(self.stream, self.parser._stream)
self.assertEqual(self.terminal_width, self.parser._terminal_width) self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
self.assertSetEqual(set(), self.parser._flags) self.assertSetEqual(set(), self.parser._flags)
self.assertIsNone(self.parser._commands) self.assertIsNone(self.parser._commands)
@ -89,7 +88,7 @@ class ControlParserTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123' mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'foo-bar' expected_name = 'foo-bar'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
to_omit = ['abc', 'xyz'] to_omit = ['abc', 'xyz']
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs) output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
@ -107,7 +106,7 @@ class ControlParserTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123' mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'get-prop' expected_name = 'get-prop'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs) output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_called_once_with(get_prop) mock_get_first_doc_line.assert_called_once_with(get_prop)
@ -119,7 +118,7 @@ class ControlParserTestCase(TestCase):
prop = property(get_prop, set_prop) prop = property(get_prop, set_prop)
expected_help = f"Get/set the `.{expected_name}` property" expected_help = f"Get/set the `.{expected_name}` property"
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs) output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)]) mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
@ -152,8 +151,7 @@ class ControlParserTestCase(TestCase):
mock_subparser = MagicMock(set_defaults=mock_set_defaults) mock_subparser = MagicMock(set_defaults=mock_set_defaults)
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
x = 'x' x = 'x'
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer, common_kwargs = {'stream': self.parser._stream, parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
expected_output = {'method': mock_subparser, 'prop': mock_subparser} expected_output = {'method': mock_subparser, 'prop': mock_subparser}
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x) output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
@ -170,12 +168,12 @@ class ControlParserTestCase(TestCase):
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs) mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
def test__print_message(self): def test__print_message(self):
self.stream_writer.write = MagicMock() self.stream.write = MagicMock()
self.assertIsNone(self.parser._print_message('')) self.assertIsNone(self.parser._print_message(''))
self.stream_writer.write.assert_not_called() self.stream.write.assert_not_called()
msg = 'foo bar baz' msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg)) self.assertIsNone(self.parser._print_message(msg))
self.stream_writer.write.assert_called_once_with(msg.encode()) self.stream.write.assert_called_once_with(msg)
@patch.object(parser.ControlParser, '_print_message') @patch.object(parser.ControlParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock): def test_exit(self, mock__print_message: MagicMock):

View File

@ -21,11 +21,12 @@ Unittests for the `asyncio_taskpool.session` module.
import json import json
from argparse import ArgumentError, Namespace from argparse import ArgumentError, Namespace
from io import StringIO
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session from asyncio_taskpool.control import session
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES
from asyncio_taskpool.exceptions import HelpRequested from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool from asyncio_taskpool.pool import SimpleTaskPool
@ -61,14 +62,15 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.assertEqual(self.mock_reader, self.session._reader) self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer) self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser) self.assertIsNone(self.session._parser)
self.assertIsInstance(self.session._response_buffer, StringIO)
@patch.object(session, 'return_or_exception') @patch.object(session, 'return_or_exception')
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock): async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
def method(self, arg1, arg2, *var_args, **rest): pass def method(self, arg1, arg2, *var_args, **rest): pass
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11} test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args}
mock_return_or_exception.return_value = None mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs)) self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs, **test_rest))
mock_return_or_exception.assert_awaited_once_with( mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
) )
@ -104,7 +106,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_reader.read = mock_read self.mock_reader.read = mock_read
self.mock_writer.drain = AsyncMock() self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = { expected_parser_kwargs = {
STREAM_WRITER: self.mock_writer, 'stream': self.session._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: width, CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '', 'prog': '',
'usage': f'[-h] [{CMD}] ...' 'usage': f'[-h] [{CMD}] ...'
@ -132,10 +134,9 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
kwargs = {FOO: BAR, 'hello': 'python'} kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs)) mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args) self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs) mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
mock__exec_property_and_respond.assert_not_called() mock__exec_property_and_respond.assert_not_called()
@ -145,7 +146,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs) mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_called() mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs) mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
@ -161,26 +162,28 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
mock__exec_method_and_respond.assert_not_called() mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_not_called() mock__exec_property_and_respond.assert_not_called()
self.mock_writer.write.assert_called_once_with(str(exc).encode()) self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_property_and_respond.reset_mock() mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock() mock_parse_args.reset_mock()
self.mock_writer.write.reset_mock() self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode()) self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited() mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited() mock__exec_property_and_respond.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock() mock_parse_args.reset_mock()
self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = HelpRequested() mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited() mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited() mock__exec_property_and_respond.assert_not_awaited()
@ -191,17 +194,23 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating" msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
response = FOO + BAR + FOO
self.session._response_buffer.write(response)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
mock__parse_command.assert_awaited_once_with(msg) mock__parse_command.assert_awaited_once_with(msg)
self.assertEqual('', self.session._response_buffer.getvalue())
self.mock_writer.write.assert_called_once_with(response.encode())
self.mock_writer.drain.assert_awaited_once_with() self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock() self.mock_reader.read.reset_mock()
mock__parse_command.reset_mock() mock__parse_command.reset_mock()
self.mock_writer.write.reset_mock()
self.mock_writer.drain.reset_mock() self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False) self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited() self.mock_reader.read.assert_not_awaited()
mock__parse_command.assert_not_awaited() mock__parse_command.assert_not_awaited()
self.mock_writer.write.assert_not_called()
self.mock_writer.drain.assert_not_awaited() self.mock_writer.drain.assert_not_awaited()

View File

@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.helpers` module.
""" """
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import helpers from asyncio_taskpool.internals import helpers
@ -122,3 +122,33 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
helpers.resolve_dotted_path('foo.bar.baz') helpers.resolve_dotted_path('foo.bar.baz')
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')]) mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
class ClassMethodWorkaroundTestCase(TestCase):
def test_init(self):
def func(): return 'foo'
def getter(): return 'bar'
prop = property(getter)
instance = helpers.ClassMethodWorkaround(func)
self.assertIs(func, instance._getter)
instance = helpers.ClassMethodWorkaround(prop)
self.assertIs(getter, instance._getter)
@patch.object(helpers.ClassMethodWorkaround, '__init__', return_value=None)
def test_get(self, _mock_init: MagicMock):
def func(x: MagicMock): return x.__name__
instance = helpers.ClassMethodWorkaround(MagicMock())
instance._getter = func
obj, cls = None, MagicMock
expected_output = 'MagicMock'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
obj = MagicMock(__name__='bar')
expected_output = 'bar'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
cls = None
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)

View File

@ -729,13 +729,15 @@ class SimpleTaskPoolTestCase(CommonTestCase):
TEST_POOL_CANCEL_CB = MagicMock() TEST_POOL_CANCEL_CB = MagicMock()
def get_task_pool_init_params(self) -> dict: def get_task_pool_init_params(self) -> dict:
return super().get_task_pool_init_params() | { params = super().get_task_pool_init_params()
params.update({
'func': self.TEST_POOL_FUNC, 'func': self.TEST_POOL_FUNC,
'args': self.TEST_POOL_ARGS, 'args': self.TEST_POOL_ARGS,
'kwargs': self.TEST_POOL_KWARGS, 'kwargs': self.TEST_POOL_KWARGS,
'end_callback': self.TEST_POOL_END_CB, 'end_callback': self.TEST_POOL_END_CB,
'cancel_callback': self.TEST_POOL_CANCEL_CB, 'cancel_callback': self.TEST_POOL_CANCEL_CB,
} })
return params
def setUp(self) -> None: def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__') self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')