Compare commits

..

4 Commits

16 changed files with 203 additions and 51 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.6.0
version = 0.6.2
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -19,5 +19,5 @@ Brings the main classes up to package level for import convenience.
"""
from .control.server import TCPControlServer, UnixControlServer
from .pool import TaskPool, SimpleTaskPool
from .server import TCPControlServer, UnixControlServer

View File

View File

@ -19,37 +19,30 @@ CLI client entry point.
"""
import sys
from argparse import ArgumentParser
from asyncio import run
from pathlib import Path
from typing import Dict, Any
from typing import Any, Dict, Sequence
from ..constants import PACKAGE_NAME
from ..pool import TaskPool
from .client import ControlClient, TCPControlClient, UnixControlClient
from .constants import PACKAGE_NAME
from .pool import TaskPool
from .server import TCPControlServer, UnixControlServer
CONN_TYPE = 'conn_type'
CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
HOST, PORT = 'host', 'port'
def parse_cli() -> Dict[str, Any]:
def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]:
parser = ArgumentParser(
prog=PACKAGE_NAME,
description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
prog=f'{PACKAGE_NAME}.control',
description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
)
subparsers = parser.add_subparsers(title="Connection types")
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
tcp_parser.add_argument(
HOST,
@ -60,19 +53,25 @@ def parse_cli() -> Dict[str, Any]:
type=int,
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
return vars(parser.parse_args())
tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient})
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
)
unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient})
return vars(parser.parse_args(args))
async def main():
kwargs = parse_cli()
if kwargs[CONN_TYPE] == UNIX:
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP:
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
else:
print("Invalid connection type", file=sys.stderr)
sys.exit(2)
await client.start()
client_cls = kwargs.pop(CLIENT_CLASS)
await client_cls(**kwargs).start()
if __name__ == '__main__':
run(main())

View File

@ -27,8 +27,8 @@ from asyncio.streams import StreamReader, StreamWriter, open_connection
from pathlib import Path
from typing import Optional, Union
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT, PathT
from ..constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from ..types import ClientConnT, PathT
class ControlClient(ABC):

View File

@ -25,9 +25,9 @@ from inspect import Parameter, getmembers, isfunction, signature
from shutil import get_terminal_size
from typing import Callable, Container, Dict, Set, Type, TypeVar
from .constants import CLIENT_INFO, CMD, STREAM_WRITER
from .exceptions import HelpRequested
from .helpers import get_first_doc_line
from ..constants import CLIENT_INFO, CMD, STREAM_WRITER
from ..exceptions import HelpRequested
from ..helpers import get_first_doc_line
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])

View File

@ -28,10 +28,10 @@ from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Optional, Union
from ..pool import TaskPool, SimpleTaskPool
from ..types import ConnectedCallbackT
from .client import ControlClient, TCPControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession
from .types import ConnectedCallbackT
log = logging.getLogger(__name__)

View File

@ -26,10 +26,10 @@ from asyncio.streams import StreamReader, StreamWriter
from inspect import isfunction, signature
from typing import Callable, Optional, Union, TYPE_CHECKING
from .constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from .exceptions import HelpRequested
from .helpers import return_or_exception
from .pool import TaskPool, SimpleTaskPool
from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from ..exceptions import CommandError, HelpRequested
from ..helpers import return_or_exception
from ..pool import TaskPool, SimpleTaskPool
from .parser import ControlParser
if TYPE_CHECKING:
@ -131,7 +131,7 @@ class ControlSession:
STREAM_WRITER: self._writer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '',
'usage': f'%(prog)s [-h] [{CMD}] ...'
'usage': f'[-h] [{CMD}] ...'
}
self._parser = ControlParser(**parser_kwargs)
self._parser.add_subparsers(title="Commands",
@ -163,6 +163,8 @@ class ControlSession:
await self._exec_method_and_respond(command, **kwargs)
elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs)
else:
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode())
async def listen(self) -> None:
"""

View File

@ -65,3 +65,7 @@ class ServerException(Exception):
class HelpRequested(ServerException):
pass
class CommandError(ServerException):
pass

View File

@ -120,7 +120,7 @@ class BaseTaskPool:
@property
def is_locked(self) -> bool:
"""Returns `True` if more the pool has been locked (see below)."""
"""Returns `True` if the pool has been locked (see below)."""
return self._locked
def lock(self) -> None:

View File

View File

@ -0,0 +1,45 @@
from pathlib import Path
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient
from asyncio_taskpool.control import __main__ as module
class CLITestCase(IsolatedAsyncioTestCase):
def test_parse_cli(self):
socket_path = '/some/path/to.sock'
args = [module.UNIX, socket_path]
expected_kwargs = {
module.CLIENT_CLASS: UnixControlClient,
module.SOCKET_PATH: Path(socket_path)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
host, port = '1.2.3.4', '1234'
args = [module.TCP, host, port]
expected_kwargs = {
module.CLIENT_CLASS: TCPControlClient,
module.HOST: host,
module.PORT: int(port)
}
parsed_kwargs = module.parse_cli(args)
self.assertDictEqual(expected_kwargs, parsed_kwargs)
with patch('sys.stderr'):
with self.assertRaises(SystemExit):
module.parse_cli(['invalid', 'foo', 'bar'])
@patch.object(module, 'parse_cli')
async def test_main(self, mock_parse_cli: MagicMock):
mock_client_start = AsyncMock()
mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
mock_client_start.assert_awaited_once_with()

View File

@ -27,7 +27,7 @@ from pathlib import Path
from unittest import IsolatedAsyncioTestCase, skipIf
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import client
from asyncio_taskpool.control import client
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
@ -37,7 +37,7 @@ FOO, BAR = 'foo', 'bar'
class ControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set())
self.abstract_patcher = patch('asyncio_taskpool.control.client.ControlClient.__abstractmethods__', set())
self.print_patcher = patch.object(client, 'print')
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_print = self.print_patcher.start()
@ -172,6 +172,43 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.mock_print.assert_called_once_with("Disconnected from control server.")
class TCPControlClientTestCase(IsolatedAsyncioTestCase):
def setUp(self) -> None:
self.base_init_patcher = patch.object(client.ControlClient, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.host, self.port = 'localhost', 12345
self.kwargs = {FOO: 123, BAR: 456}
self.client = client.TCPControlClient(host=self.host, port=self.port, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test_init(self):
self.assertEqual(self.host, self.client._host)
self.assertEqual(self.port, self.client._port)
self.mock_base_init.assert_called_once_with(**self.kwargs)
@patch.object(client, 'print')
@patch.object(client, 'open_connection')
async def test__open_connection(self, mock_open_connection: AsyncMock, mock_print: MagicMock):
mock_open_connection.return_value = expected_output = 'something'
kwargs = {'a': 1, 'b': 2}
output = await self.client._open_connection(**kwargs)
self.assertEqual(expected_output, output)
mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs)
mock_print.assert_not_called()
mock_open_connection.reset_mock()
mock_open_connection.side_effect = e = ConnectionError()
output1, output2 = await self.client._open_connection(**kwargs)
self.assertIsNone(output1)
self.assertIsNone(output2)
mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs)
mock_print.assert_called_once_with(str(e), file=sys.stderr)
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
class UnixControlClientTestCase(IsolatedAsyncioTestCase):

View File

@ -24,7 +24,7 @@ from inspect import signature
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
from asyncio_taskpool import parser
from asyncio_taskpool.control import parser
from asyncio_taskpool.exceptions import HelpRequested
@ -157,6 +157,14 @@ class ControlServerTestCase(TestCase):
mock_add_property_command.assert_called_once_with(FooBar.prop, FooBar.__name__, **common_kwargs)
mock_set_defaults.assert_has_calls([call(**{x: FooBar.method}), call(**{x: FooBar.prop})])
@patch.object(parser.ArgumentParser, 'add_subparsers')
def test_add_subparsers(self, mock_base_add_subparsers: MagicMock):
args, kwargs = [1, 2, 42], {FOO: 123, BAR: 456}
mock_base_add_subparsers.return_value = mock_action = MagicMock()
output = self.parser.add_subparsers(*args, **kwargs)
self.assertEqual(mock_action, output)
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
def test__print_message(self):
self.stream_writer.write = MagicMock()
self.assertIsNone(self.parser._print_message(''))

View File

@ -26,8 +26,8 @@ from pathlib import Path
from unittest import IsolatedAsyncioTestCase, skipIf
from unittest.mock import AsyncMock, MagicMock, patch
from asyncio_taskpool import server
from asyncio_taskpool.client import ControlClient, UnixControlClient
from asyncio_taskpool.control import server
from asyncio_taskpool.control.client import ControlClient, TCPControlClient, UnixControlClient
FOO, BAR = 'foo', 'bar'
@ -46,7 +46,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set())
self.abstract_patcher = patch('asyncio_taskpool.control.server.ControlServer.__abstractmethods__', set())
self.mock_abstract_methods = self.abstract_patcher.start()
self.mock_pool = MagicMock()
self.kwargs = {FOO: 123, BAR: 456}
@ -120,6 +120,50 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_create_task.assert_called_once_with(mock_awaitable)
class TCPControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = server.log.level
server.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
server.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self.base_init_patcher = patch.object(server.ControlServer, '__init__')
self.mock_base_init = self.base_init_patcher.start()
self.mock_pool = MagicMock()
self.host, self.port = 'localhost', 12345
self.kwargs = {FOO: 123, BAR: 456}
self.server = server.TCPControlServer(pool=self.mock_pool, host=self.host, port=self.port, **self.kwargs)
def tearDown(self) -> None:
self.base_init_patcher.stop()
def test__client_class(self):
self.assertEqual(TCPControlClient, self.server._client_class)
def test_init(self):
self.assertEqual(self.host, self.server._host)
self.assertEqual(self.port, self.server._port)
self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs)
@patch.object(server, 'start_server')
async def test__get_server_instance(self, mock_start_server: AsyncMock):
mock_start_server.return_value = expected_output = 'totally_a_server'
mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2}
args = [mock_callback]
output = await self.server._get_server_instance(*args, **mock_kwargs)
self.assertEqual(expected_output, output)
mock_start_server.assert_called_once_with(mock_callback, self.host, self.port, **mock_kwargs)
def test__final_callback(self):
self.assertIsNone(self.server._final_callback())
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
log_lvl: int

View File

@ -24,7 +24,7 @@ from argparse import ArgumentError, Namespace
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool import session
from asyncio_taskpool.control import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER
from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool
@ -107,7 +107,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
STREAM_WRITER: self.mock_writer,
CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '',
'usage': f'%(prog)s [-h] [{CMD}] ...'
'usage': f'[-h] [{CMD}] ...'
}
expected_subparsers_kwargs = {
'title': "Commands",
@ -142,9 +142,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock__exec_method_and_respond.reset_mock()
mock_parse_args.reset_mock()
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: prop}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called()
@ -154,6 +152,21 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock()
bad_command = 'definitely not a function or property'
mock_parse_args.return_value = Namespace(**{CMD: bad_command}, **kwargs)
with patch.object(session, 'CommandError') as cmd_err_cls:
cmd_err_cls.return_value = exc = MagicMock()
self.assertIsNone(await self.session._parse_command(msg))
cmd_err_cls.assert_called_once_with(f"Unknown command object: {bad_command}")
mock_parse_args.assert_called_once_with(msg.split(' '))
mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_not_called()
self.mock_writer.write.assert_called_once_with(str(exc).encode())
mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock()
self.mock_writer.write.reset_mock()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' '))