2022-02-09 23:14:42 +01:00
|
|
|
__author__ = "Daniil Fajnberg"
|
|
|
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
|
|
|
__license__ = """GNU LGPLv3.0
|
|
|
|
|
|
|
|
This file is part of asyncio-taskpool.
|
|
|
|
|
|
|
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
|
|
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
|
|
|
|
|
|
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
|
|
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
|
|
|
See the GNU Lesser General Public License for more details.
|
|
|
|
|
|
|
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
|
|
|
If not, see <https://www.gnu.org/licenses/>."""
|
|
|
|
|
|
|
|
__doc__ = """
|
|
|
|
This module contains the task pool control server class definitions.
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
2022-02-04 16:25:09 +01:00
|
|
|
import logging
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from asyncio import AbstractServer
|
|
|
|
from asyncio.exceptions import CancelledError
|
|
|
|
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
|
|
|
from asyncio.tasks import Task, create_task
|
|
|
|
from pathlib import Path
|
2022-02-12 22:51:52 +01:00
|
|
|
from typing import Optional, Union
|
2022-02-04 16:25:09 +01:00
|
|
|
|
|
|
|
from .client import ControlClient, UnixControlClient
|
2022-02-12 22:51:52 +01:00
|
|
|
from .pool import TaskPool, SimpleTaskPool
|
|
|
|
from .session import ControlSession
|
2022-02-13 16:17:50 +01:00
|
|
|
from .types import ConnectedCallbackT
|
2022-02-04 16:25:09 +01:00
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module so output can be filtered/configured per module.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
2022-02-05 18:02:32 +01:00
|
|
|
class _ClassProperty:
    """
    Minimal read-only descriptor computing its value from the owner class.

    Stand-in for stacking `@classmethod` on top of `@property`: that combination was deprecated in Python 3.11 and
    stopped working in Python 3.13, where class-level access returns a bound method instead of the computed value.
    The decorated function receives the class (never an instance), both for `Cls.attr` and `obj.attr` access.
    """

    def __init__(self, fget) -> None:
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner=None):
        # `owner` may be omitted in manual `__get__` calls; fall back to the instance's type then.
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
    """
    Abstract base class for a task pool control server.

    This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client
    connecting to it. The entire interface is defined within that session class.
    """
    _client_class = ControlClient

    @_ClassProperty
    def client_class_name(cls) -> str:
        """Returns the name of the control client class matching the server class."""
        return cls._client_class.__name__

    @abstractmethod
    async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
        """
        Initializes, starts, and returns an async server instance (Unix or TCP type).

        Args:
            client_connected_cb:
                The callback for when a client connects to the server (as per `asyncio.start_server` or
                `asyncio.start_unix_server`); will always be the internal `_client_connected_cb` method.
            **kwargs (optional):
                Keyword arguments to pass into the function that starts the server.

        Returns:
            The running server object (a type of `asyncio.Server`).
        """
        raise NotImplementedError

    @abstractmethod
    def _final_callback(self) -> None:
        """The method to run after the server's `serve_forever` method ends for whatever reason."""
        raise NotImplementedError

    def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
        """
        Initializes by merely saving the internal attributes, but without starting the server yet.
        The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
        tied to one specific task pool.

        Args:
            pool:
                An instance of a `BaseTaskPool` subclass to tie the server to.
            **server_kwargs (optional):
                Keyword arguments that will be passed into the function that starts the server.
        """
        self._pool: Union[TaskPool, SimpleTaskPool] = pool
        self._server_kwargs = server_kwargs
        # Only set once `serve_forever` has started the underlying server.
        self._server: Optional[AbstractServer] = None

    @property
    def pool(self) -> Union[TaskPool, SimpleTaskPool]:
        """Read-only property for accessing the task pool instance controlled by the server."""
        return self._pool

    def is_serving(self) -> bool:
        """
        Wrapper around the `asyncio.Server.is_serving` method.

        Returns:
            `False` if the server has not been started yet (i.e. `serve_forever` was never called);
            otherwise whatever the underlying server's `is_serving` method returns.
        """
        return self._server is not None and self._server.is_serving()

    async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
        """
        The universal client callback that will be passed into the `_get_server_instance` method.
        Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
        The connection's writer is closed once the session ends, whether normally or by an exception.
        """
        session = ControlSession(self, reader, writer)
        try:
            await session.client_handshake()
            await session.listen()
        finally:
            # Release the client's transport once the session is over; closing an already closing transport
            # is a no-op, so this is safe even if the session closed the writer itself.
            writer.close()

    async def _serve_forever(self) -> None:
        """
        To be run as an `asyncio.Task` by the following method.
        Serves as a wrapper around the `asyncio.Server.serve_forever` method that ensures that the `_final_callback`
        method is called, when the former method ends for whatever reason.
        """
        try:
            # Leaving the context closes the server and waits for it to be closed.
            async with self._server:
                await self._server.serve_forever()
        except CancelledError:
            # Cancellation is treated as a regular stop: it is logged but not re-raised,
            # so the wrapping task finishes normally instead of as cancelled.
            log.debug("%s stopped", self.__class__.__name__)
        finally:
            self._final_callback()

    async def serve_forever(self) -> Task:
        """
        This method actually starts the server and begins listening to client connections on the specified interface.
        It should never block because the serving will be performed in a separate task.

        Returns:
            The `asyncio.Task` running the server until it is stopped (e.g. by cancelling that task).
        """
        log.debug("Starting %s...", self.__class__.__name__)
        self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
        return create_task(self._serve_forever())
|
|
|
|
|
|
|
|
|
|
|
|
class UnixControlServer(ControlServer):
    """Task pool control server class that exposes a unix socket for control clients to connect to."""
    _client_class = UnixControlClient

    def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
        """
        Initializes the server, extracting the socket path from the keyword arguments.

        Args:
            pool:
                The task pool instance to tie the server to.
            **server_kwargs:
                Must contain `path`, the filesystem path of the unix socket to create. That entry is removed and
                stored separately; all remaining keyword arguments are passed on to `asyncio.start_unix_server`.

        Raises:
            `KeyError` if no `path` keyword argument was provided.
        """
        self._socket_path = Path(server_kwargs.pop('path'))
        super().__init__(pool, **server_kwargs)

    async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
        """Starts a unix socket server listening on the configured socket path and returns it."""
        server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
        log.debug("Opened socket '%s'", self._socket_path)
        return server

    def _final_callback(self) -> None:
        """Removes the unix socket on which the server was listening, if it still exists."""
        # Since Python 3.13 the server removes its own socket file when it closes (`cleanup_socket=True` by default);
        # `missing_ok` keeps a `FileNotFoundError` from escaping the `finally` clause that calls this method.
        self._socket_path.unlink(missing_ok=True)
        log.debug("Removed socket '%s'", self._socket_path)
|