Compare commits

...

2 Commits

Author SHA1 Message Date
b6aed727e9 additional unit tests 2022-02-13 19:55:27 +01:00
c9a8d9ecd1 better placeholders 2022-02-13 19:45:53 +01:00
4 changed files with 51 additions and 5 deletions

View File

@ -55,5 +55,13 @@ class ServerException(Exception):
pass pass
class UnknownTaskPoolClass(ServerException):
pass
class NotATaskPool(ServerException):
pass
class HelpRequested(ServerException): class HelpRequested(ServerException):
pass pass

View File

@ -19,7 +19,6 @@ Miscellaneous helper functions.
""" """
import re
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue from asyncio.queues import Queue
from inspect import getdoc from inspect import getdoc
@ -57,7 +56,7 @@ def tasks_str(num: int) -> str:
def get_first_doc_line(obj: object) -> str: def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0] return getdoc(obj).strip().split("\n", 1)[0].strip()
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]: async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:

View File

@ -5,9 +5,9 @@ from asyncio.streams import StreamReader, StreamWriter
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO from .constants import CMD, SESSION_PARSER_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
from .exceptions import HelpRequested from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
from .helpers import get_first_doc_line, return_or_exception, tasks_str from .helpers import get_first_doc_line, return_or_exception, tasks_str
from .pool import TaskPool, SimpleTaskPool from .pool import BaseTaskPool, TaskPool, SimpleTaskPool
from .session_parser import CommandParser, NUM from .session_parser import CommandParser, NUM
if TYPE_CHECKING: if TYPE_CHECKING:
@ -67,6 +67,9 @@ class ControlSession:
CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget) CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)
) )
def _add_advanced_commands(self) -> None:
raise NotImplementedError
def _init_parser(self, client_terminal_width: int) -> None: def _init_parser(self, client_terminal_width: int) -> None:
parser_kwargs = { parser_kwargs = {
'prog': '', 'prog': '',
@ -76,9 +79,13 @@ class ControlSession:
self._parser = CommandParser(**parser_kwargs) self._parser = CommandParser(**parser_kwargs)
self._add_base_commands() self._add_base_commands()
if isinstance(self._pool, TaskPool): if isinstance(self._pool, TaskPool):
pass # TODO self._add_advanced_commands()
elif isinstance(self._pool, SimpleTaskPool): elif isinstance(self._pool, SimpleTaskPool):
self._add_simple_commands() self._add_simple_commands()
elif isinstance(self._pool, BaseTaskPool):
raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}")
else:
raise NotATaskPool(f"Not a task pool instance: {self._pool}")
async def client_handshake(self) -> None: async def client_handshake(self) -> None:
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())

View File

@ -94,3 +94,35 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
self.assertEqual("tasks", helpers.tasks_str(2)) self.assertEqual("tasks", helpers.tasks_str(2))
self.assertEqual("tasks", helpers.tasks_str(-10)) self.assertEqual("tasks", helpers.tasks_str(-10))
self.assertEqual("tasks", helpers.tasks_str(42)) self.assertEqual("tasks", helpers.tasks_str(42))
def test_get_first_doc_line(self):
expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output}
something else
even more
""")
output = helpers.get_first_doc_line(mock_obj)
self.assertEqual(expected_output, output)
async def test_return_or_exception(self):
expected_output = '420'
mock_func = AsyncMock(return_value=expected_output)
args = (1, 3, 5)
kwargs = {'a': 1, 'b': 2, 'c': 'foo'}
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_awaited_once_with(*args, **kwargs)
mock_func = MagicMock(return_value=expected_output)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
class TestException(Exception):
pass
test_exception = TestException()
mock_func = MagicMock(side_effect=test_exception)
output = await helpers.return_or_exception(mock_func, *args, **kwargs)
self.assertEqual(test_exception, output)
mock_func.assert_called_once_with(*args, **kwargs)