From 6a5c200ae63ac8ef3649a98e89ab392db6266208 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Fri, 4 Feb 2022 16:25:09 +0100 Subject: [PATCH] working server and client --- src/asyncio_taskpool/__init__.py | 1 + src/asyncio_taskpool/client.py | 56 +++++++++++++++ src/asyncio_taskpool/constants.py | 6 ++ src/asyncio_taskpool/pool.py | 20 ++++-- src/asyncio_taskpool/server.py | 115 ++++++++++++++++++++++++++++++ src/asyncio_taskpool/types.py | 5 +- 6 files changed, 197 insertions(+), 6 deletions(-) create mode 100644 src/asyncio_taskpool/client.py create mode 100644 src/asyncio_taskpool/constants.py create mode 100644 src/asyncio_taskpool/server.py diff --git a/src/asyncio_taskpool/__init__.py b/src/asyncio_taskpool/__init__.py index 4a64d4c..d2f82a0 100644 --- a/src/asyncio_taskpool/__init__.py +++ b/src/asyncio_taskpool/__init__.py @@ -1 +1,2 @@ from .pool import TaskPool +from .server import UnixControlServer diff --git a/src/asyncio_taskpool/client.py b/src/asyncio_taskpool/client.py new file mode 100644 index 0000000..e63d52d --- /dev/null +++ b/src/asyncio_taskpool/client.py @@ -0,0 +1,56 @@ +import sys +from abc import ABC, abstractmethod +from asyncio.streams import StreamReader, StreamWriter, open_unix_connection +from pathlib import Path + +from asyncio_taskpool import constants +from asyncio_taskpool.types import ClientConnT + + +class ControlClient(ABC): + + @abstractmethod + async def open_connection(self, **kwargs) -> ClientConnT: + raise NotImplementedError + + def __init__(self, **conn_kwargs) -> None: + self._conn_kwargs = conn_kwargs + self._connected: bool = False + + async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: + try: + msg = input("> ").strip().lower() + except EOFError: + msg = constants.CLIENT_EXIT + if msg == constants.CLIENT_EXIT: + writer.close() + self._connected = False + return + writer.write(msg.encode()) + await writer.drain() + print("Command sent; awaiting response...") + print("Server response:", (await reader.read(constants.MSG_BYTES)).decode()) + + async def start(self): + reader, writer = await self.open_connection(**self._conn_kwargs) + if reader is None: + print("Failed to connect.", file=sys.stderr) + return + self._connected = True + print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) + while self._connected: + await self._interact(reader, writer) + print("Disconnected from control server.") + + +class UnixControlClient(ControlClient): + def __init__(self, **conn_kwargs) -> None: + self._socket_path = Path(conn_kwargs.pop('path')) + super().__init__(**conn_kwargs) + + async def open_connection(self, **kwargs) -> ClientConnT: + try: + return await open_unix_connection(self._socket_path, **kwargs) + except FileNotFoundError: + print("No socket at", self._socket_path, file=sys.stderr) + return None, None diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py new file mode 100644 index 0000000..954c692 --- /dev/null +++ b/src/asyncio_taskpool/constants.py @@ -0,0 +1,6 @@ +PACKAGE_NAME = 'asyncio_taskpool' +MSG_BYTES = 1024 +CMD_START = 'start' +CMD_STOP = 'stop' +CMD_STOP_ALL = 'stop_all' +CLIENT_EXIT = 'exit' diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index b0bc38e..87f01e1 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -19,6 +19,7 @@ class TaskPool: self._final_callback: FinalCallbackT = final_callback self._cancel_callback: CancelCallbackT = cancel_callback self._tasks: List[Task] = [] + self._cancelled: List[Task] = [] @property def func_name(self) -> str: @@ -43,11 +44,20 @@ class TaskPool: self._start_one() def stop(self, num: int = 1) -> int: - if num < 1: - return 0 - return sum(task.cancel() for task in reversed(self._tasks[-num:])) + for i in range(num): + try: + task = self._tasks.pop() + except IndexError: + num = i + break + task.cancel() + self._cancelled.append(task) + return num + + def stop_all(self) -> int: + return self.stop(self.size) async def gather(self, return_exceptions: bool = False): - results = await gather(*self._tasks, return_exceptions=return_exceptions) - self._tasks = [] + results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions) + self._tasks = self._cancelled = [] return results diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py new file mode 100644 index 0000000..27231e8 --- /dev/null +++ b/src/asyncio_taskpool/server.py @@ -0,0 +1,115 @@ +import logging +from abc import ABC, abstractmethod +from asyncio import AbstractServer +from asyncio.exceptions import CancelledError +from asyncio.streams import StreamReader, StreamWriter, start_unix_server +from asyncio.tasks import Task, create_task +from pathlib import Path +from typing import Tuple, Union, Optional + +from . import constants +from .pool import TaskPool +from .client import ControlClient, UnixControlClient + + +log = logging.getLogger(__name__) + + +def tasks_str(num: int) -> str: + return "tasks" if num != 1 else "task" + + +def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]: + cmd = msg.strip().split(' ', 1) + if len(cmd) > 1: + try: + return cmd[0], int(cmd[1]) + except ValueError: + return None, None + return cmd[0], None + + +class ControlServer(ABC): + client_class = ControlClient + + @abstractmethod + async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: + raise NotImplementedError + + @abstractmethod + def final_callback(self) -> None: + raise NotImplementedError + + def __init__(self, pool: TaskPool, **server_kwargs) -> None: + self._pool: TaskPool = pool + self._server_kwargs = server_kwargs + self._server: Optional[AbstractServer] = None + + def _start_tasks(self, num: int, writer: StreamWriter) -> None: + log.debug("Client requests starting %s tasks", num) + self._pool.start(num) + size = self._pool.size + writer.write(f"{num} new {tasks_str(num)} started! {size} {tasks_str(size)} active now.".encode()) + + def _stop_tasks(self, num: int, writer: StreamWriter) -> None: + log.debug("Client requests stopping %s tasks", num) + num = self._pool.stop(num) # the requested number may be greater than the total number of running tasks + size = self._pool.size + writer.write(f"{num} {tasks_str(num)} stopped! {size} {tasks_str(size)} left.".encode()) + + def _stop_all_tasks(self, writer: StreamWriter) -> None: + log.debug("Client requests stopping all tasks") + num = self._pool.stop_all() + writer.write(f"Remaining {num} {tasks_str(num)} stopped!".encode()) + + async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: + log.debug("%s connected", self.client_class.__name__) + writer.write(f"{self.__class__.__name__} for {self._pool}".encode()) + await writer.drain() + while True: + msg = (await reader.read(constants.MSG_BYTES)).decode().strip() + if not msg: + log.debug("%s disconnected", self.client_class.__name__) + break + cmd, arg = get_cmd_arg(msg) + if cmd == constants.CMD_START: + self._start_tasks(arg, writer) + elif cmd == constants.CMD_STOP: + self._stop_tasks(arg, writer) + elif cmd == constants.CMD_STOP_ALL: + self._stop_all_tasks(writer) + else: + log.debug("%s sent invalid command: %s", self.client_class.__name__, msg) + writer.write(b"Invalid command!") + await writer.drain() + + async def _serve_forever(self) -> None: + try: + async with self._server: + await self._server.serve_forever() + except CancelledError: + log.debug("%s stopped", self.__class__.__name__) + finally: + self.final_callback() + + async def serve_forever(self) -> Task: + log.debug("Starting %s...", self.__class__.__name__) + self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs) + return create_task(self._serve_forever()) + + +class UnixControlServer(ControlServer): + client_class = UnixControlClient + + def __init__(self, pool: TaskPool, **server_kwargs) -> None: + self._socket_path = Path(server_kwargs.pop('path')) + super().__init__(pool, **server_kwargs) + + async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: + srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) + log.debug("Opened socket '%s'", str(self._socket_path)) + return srv + + def final_callback(self) -> None: + self._socket_path.unlink() + log.debug("Removed socket '%s'", str(self._socket_path)) diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index f2c16d4..94a0589 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -1,6 +1,9 @@ -from typing import Callable, Awaitable, Any +from asyncio.streams import StreamReader, StreamWriter +from typing import Tuple, Callable, Awaitable, Union, Any CoroutineFunc = Callable[[...], Awaitable[Any]] FinalCallbackT = Callable CancelCallbackT = Callable + +ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]