massive overhaul of the control server interface;

making use of `ArgumentParser` for client commands;
new `ControlSession` object instantiated upon connection
This commit is contained in:
2022-02-12 22:51:52 +01:00
parent 024e5db0d4
commit be03097bf4
10 changed files with 300 additions and 78 deletions

View File

@ -26,32 +26,23 @@ from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Tuple, Union, Optional
from typing import Optional, Union
from . import constants
from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession
log = logging.getLogger(__name__)
def tasks_str(num: int) -> str:
return "tasks" if num != 1 else "task"
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
cmd = msg.strip().split(' ', 1)
if len(cmd) > 1:
try:
return cmd[0], int(cmd[1])
except ValueError:
return None, None
return cmd[0], None
class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
client_class = ControlClient
_client_class = ControlClient
@classmethod
@property
def client_class_name(cls) -> str:
return cls._client_class.__name__
@abstractmethod
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
@ -61,63 +52,25 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
def final_callback(self) -> None:
raise NotImplementedError
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._pool: SimpleTaskPool = pool
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
self._pool: Union[TaskPool, SimpleTaskPool] = pool
self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None:
num = 1
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
writer.write(str(await self._pool.start(num)).encode())
def __str__(self) -> str:
return f"{self.__class__.__name__} for {self._pool}"
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None:
num = 1
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
# the requested number may be greater than the total number of running tasks
writer.write(str(self._pool.stop(num)).encode())
@property
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
return self._pool
def _stop_all_tasks(self, writer: StreamWriter) -> None:
log.debug("%s requests stopping all tasks", self.client_class.__name__)
writer.write(str(self._pool.stop_all()).encode())
def _pool_size(self, writer: StreamWriter) -> None:
log.debug("%s requests number of running tasks", self.client_class.__name__)
writer.write(str(self._pool.num_running).encode())
def _pool_func(self, writer: StreamWriter) -> None:
log.debug("%s requests pool function", self.client_class.__name__)
writer.write(self._pool.func_name.encode())
async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
while self._server.is_serving():
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
if not msg:
log.debug("%s disconnected", self.client_class.__name__)
break
cmd, arg = get_cmd_arg(msg)
if cmd == constants.CMD_START:
await self._start_tasks(writer, arg)
elif cmd == constants.CMD_STOP:
self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL:
self._stop_all_tasks(writer)
elif cmd == constants.CMD_NUM_RUNNING:
self._pool_size(writer)
elif cmd == constants.CMD_FUNC:
self._pool_func(writer)
else:
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
writer.write(b"Invalid command!")
await writer.drain()
def is_serving(self) -> bool:
return self._server.is_serving()
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
log.debug("%s connected", self.client_class.__name__)
writer.write(str(self._pool).encode())
await writer.drain()
await self._listen(reader, writer)
session = ControlSession(self, reader, writer)
await session.client_handshake()
await session.listen()
async def _serve_forever(self) -> None:
try:
@ -135,7 +88,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
class UnixControlServer(ControlServer):
client_class = UnixControlClient
_client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._socket_path = Path(server_kwargs.pop('path'))