generated from daniil-berg/boilerplate-py
working server and client
This commit is contained in:
parent
c9e0e2f255
commit
6a5c200ae6
@ -1 +1,2 @@
|
||||
from .pool import TaskPool
|
||||
from .server import UnixControlServer
|
||||
|
56
src/asyncio_taskpool/client.py
Normal file
56
src/asyncio_taskpool/client.py
Normal file
@ -0,0 +1,56 @@
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||
from pathlib import Path
|
||||
|
||||
from asyncio_taskpool import constants
|
||||
from asyncio_taskpool.types import ClientConnT
|
||||
|
||||
|
||||
class ControlClient(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, **conn_kwargs) -> None:
|
||||
self._conn_kwargs = conn_kwargs
|
||||
self._connected: bool = False
|
||||
|
||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||
try:
|
||||
msg = input("> ").strip().lower()
|
||||
except EOFError:
|
||||
msg = constants.CLIENT_EXIT
|
||||
if msg == constants.CLIENT_EXIT:
|
||||
writer.close()
|
||||
self._connected = False
|
||||
return
|
||||
writer.write(msg.encode())
|
||||
await writer.drain()
|
||||
print("Command sent; awaiting response...")
|
||||
print("Server response:", (await reader.read(constants.MSG_BYTES)).decode())
|
||||
|
||||
async def start(self):
|
||||
reader, writer = await self.open_connection(**self._conn_kwargs)
|
||||
if reader is None:
|
||||
print("Failed to connect.", file=sys.stderr)
|
||||
return
|
||||
self._connected = True
|
||||
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
|
||||
while self._connected:
|
||||
await self._interact(reader, writer)
|
||||
print("Disconnected from control server.")
|
||||
|
||||
|
||||
class UnixControlClient(ControlClient):
|
||||
def __init__(self, **conn_kwargs) -> None:
|
||||
self._socket_path = Path(conn_kwargs.pop('path'))
|
||||
super().__init__(**conn_kwargs)
|
||||
|
||||
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||
try:
|
||||
return await open_unix_connection(self._socket_path, **kwargs)
|
||||
except FileNotFoundError:
|
||||
print("No socket at", self._socket_path, file=sys.stderr)
|
||||
return None, None
|
6
src/asyncio_taskpool/constants.py
Normal file
6
src/asyncio_taskpool/constants.py
Normal file
@ -0,0 +1,6 @@
|
||||
PACKAGE_NAME = 'asyncio_taskpool'
|
||||
MSG_BYTES = 1024
|
||||
CMD_START = 'start'
|
||||
CMD_STOP = 'stop'
|
||||
CMD_STOP_ALL = 'stop_all'
|
||||
CLIENT_EXIT = 'exit'
|
@ -19,6 +19,7 @@ class TaskPool:
|
||||
self._final_callback: FinalCallbackT = final_callback
|
||||
self._cancel_callback: CancelCallbackT = cancel_callback
|
||||
self._tasks: List[Task] = []
|
||||
self._cancelled: List[Task] = []
|
||||
|
||||
@property
|
||||
def func_name(self) -> str:
|
||||
@ -43,11 +44,20 @@ class TaskPool:
|
||||
self._start_one()
|
||||
|
||||
def stop(self, num: int = 1) -> int:
|
||||
if num < 1:
|
||||
return 0
|
||||
return sum(task.cancel() for task in reversed(self._tasks[-num:]))
|
||||
for i in range(num):
|
||||
try:
|
||||
task = self._tasks.pop()
|
||||
except IndexError:
|
||||
num = i
|
||||
break
|
||||
task.cancel()
|
||||
self._cancelled.append(task)
|
||||
return num
|
||||
|
||||
def stop_all(self) -> int:
|
||||
return self.stop(self.size)
|
||||
|
||||
async def gather(self, return_exceptions: bool = False):
|
||||
results = await gather(*self._tasks, return_exceptions=return_exceptions)
|
||||
self._tasks = []
|
||||
results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions)
|
||||
self._tasks = self._cancelled = []
|
||||
return results
|
||||
|
115
src/asyncio_taskpool/server.py
Normal file
115
src/asyncio_taskpool/server.py
Normal file
@ -0,0 +1,115 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import AbstractServer
|
||||
from asyncio.exceptions import CancelledError
|
||||
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
||||
from asyncio.tasks import Task, create_task
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union, Optional
|
||||
|
||||
from . import constants
|
||||
from .pool import TaskPool
|
||||
from .client import ControlClient, UnixControlClient
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tasks_str(num: int) -> str:
|
||||
return "tasks" if num != 1 else "task"
|
||||
|
||||
|
||||
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
|
||||
cmd = msg.strip().split(' ', 1)
|
||||
if len(cmd) > 1:
|
||||
try:
|
||||
return cmd[0], int(cmd[1])
|
||||
except ValueError:
|
||||
return None, None
|
||||
return cmd[0], None
|
||||
|
||||
|
||||
class ControlServer(ABC):
|
||||
client_class = ControlClient
|
||||
|
||||
@abstractmethod
|
||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def final_callback(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, pool: TaskPool, **server_kwargs) -> None:
|
||||
self._pool: TaskPool = pool
|
||||
self._server_kwargs = server_kwargs
|
||||
self._server: Optional[AbstractServer] = None
|
||||
|
||||
def _start_tasks(self, num: int, writer: StreamWriter) -> None:
|
||||
log.debug("Client requests starting %s tasks", num)
|
||||
self._pool.start(num)
|
||||
size = self._pool.size
|
||||
writer.write(f"{num} new {tasks_str(num)} started! {size} {tasks_str(size)} active now.".encode())
|
||||
|
||||
def _stop_tasks(self, num: int, writer: StreamWriter) -> None:
|
||||
log.debug("Client requests stopping %s tasks", num)
|
||||
num = self._pool.stop(num) # the requested number may be greater than the total number of running tasks
|
||||
size = self._pool.size
|
||||
writer.write(f"{num} {tasks_str(num)} stopped! {size} {tasks_str(size)} left.".encode())
|
||||
|
||||
def _stop_all_tasks(self, writer: StreamWriter) -> None:
|
||||
log.debug("Client requests stopping all tasks")
|
||||
num = self._pool.stop_all()
|
||||
writer.write(f"Remaining {num} {tasks_str(num)} stopped!".encode())
|
||||
|
||||
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||
log.debug("%s connected", self.client_class.__name__)
|
||||
writer.write(f"{self.__class__.__name__} for {self._pool}".encode())
|
||||
await writer.drain()
|
||||
while True:
|
||||
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
|
||||
if not msg:
|
||||
log.debug("%s disconnected", self.client_class.__name__)
|
||||
break
|
||||
cmd, arg = get_cmd_arg(msg)
|
||||
if cmd == constants.CMD_START:
|
||||
self._start_tasks(arg, writer)
|
||||
elif cmd == constants.CMD_STOP:
|
||||
self._stop_tasks(arg, writer)
|
||||
elif cmd == constants.CMD_STOP_ALL:
|
||||
self._stop_all_tasks(writer)
|
||||
else:
|
||||
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
|
||||
writer.write(b"Invalid command!")
|
||||
await writer.drain()
|
||||
|
||||
async def _serve_forever(self) -> None:
|
||||
try:
|
||||
async with self._server:
|
||||
await self._server.serve_forever()
|
||||
except CancelledError:
|
||||
log.debug("%s stopped", self.__class__.__name__)
|
||||
finally:
|
||||
self.final_callback()
|
||||
|
||||
async def serve_forever(self) -> Task:
|
||||
log.debug("Starting %s...", self.__class__.__name__)
|
||||
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
||||
return create_task(self._serve_forever())
|
||||
|
||||
|
||||
class UnixControlServer(ControlServer):
|
||||
client_class = UnixControlClient
|
||||
|
||||
def __init__(self, pool: TaskPool, **server_kwargs) -> None:
|
||||
self._socket_path = Path(server_kwargs.pop('path'))
|
||||
super().__init__(pool, **server_kwargs)
|
||||
|
||||
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||
log.debug("Opened socket '%s'", str(self._socket_path))
|
||||
return srv
|
||||
|
||||
def final_callback(self) -> None:
|
||||
self._socket_path.unlink()
|
||||
log.debug("Removed socket '%s'", str(self._socket_path))
|
@ -1,6 +1,9 @@
|
||||
from typing import Callable, Awaitable, Any
|
||||
from asyncio.streams import StreamReader, StreamWriter
|
||||
from typing import Tuple, Callable, Awaitable, Union, Any
|
||||
|
||||
|
||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||
FinalCallbackT = Callable
|
||||
CancelCallbackT = Callable
|
||||
|
||||
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user