implemented TCP socket control; switched example to TCP

This commit is contained in:
2022-02-25 22:42:37 +01:00
parent ae6bb1bd17
commit 5dad4ab0c7
5 changed files with 70 additions and 22 deletions

View File

@ -20,4 +20,4 @@ Brings the main classes up to package level for import convenience.
from .pool import TaskPool, SimpleTaskPool
from .server import UnixControlServer
from .server import TCPControlServer, UnixControlServer

View File

@ -25,15 +25,16 @@ from asyncio import run
from pathlib import Path
from typing import Dict, Any
from .client import ControlClient, UnixControlClient
from .client import ControlClient, TCPControlClient, UnixControlClient
from .constants import PACKAGE_NAME
from .pool import TaskPool
from .server import ControlServer
from .server import TCPControlServer, UnixControlServer
CONN_TYPE = 'conn_type'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
HOST, PORT = 'host', 'port'
def parse_cli() -> Dict[str, Any]:
@ -46,7 +47,18 @@ def parse_cli() -> Dict[str, Any]:
unix_parser.add_argument(
SOCKET_PATH,
type=Path,
help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
)
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
tcp_parser.add_argument(
HOST,
help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
tcp_parser.add_argument(
PORT,
type=int,
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
return vars(parser.parse_args())
@ -56,8 +68,7 @@ async def main():
if kwargs[CONN_TYPE] == UNIX:
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP:
# TODO: Implement the TCP client class
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
else:
print("Invalid connection type", file=sys.stderr)
sys.exit(2)

View File

@ -23,9 +23,9 @@ import json
import shutil
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter
from asyncio.streams import StreamReader, StreamWriter, open_connection
from pathlib import Path
from typing import Optional
from typing import Optional, Union
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT, PathT
@ -50,8 +50,8 @@ class ControlClient(ABC):
"""
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
as keyword-arguments.
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs`
(unpacked) as keyword-arguments.
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
`None` and `None`, if it failed to establish the defined connection.
"""
@ -144,15 +144,34 @@ class ControlClient(ABC):
print("Disconnected from control server.")
class TCPControlClient(ControlClient):
"""Task pool control client that expects a TCP socket to be exposed by the control server."""
def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None:
"""In addition to what the base class does, `host` and `port` are expected as non-optional arguments."""
self._host = host
self._port = port
super().__init__(**conn_kwargs)
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Wrapper around the `asyncio.open_connection` function.
Returns a tuple of `None` and `None`, if the connection can not be established;
otherwise, the stream-reader and -writer tuple is returned.
"""
try:
return await open_connection(self._host, self._port, **kwargs)
except ConnectionError as e:
print(str(e), file=sys.stderr)
return None, None
class UnixControlClient(ControlClient):
"""Task pool control client that expects a unix socket to be exposed by the control server."""
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
"""
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
"""
"""In addition to what the base class does, the `socket_path` is expected as a non-optional argument."""
from asyncio.streams import open_unix_connection
self._open_unix_connection = open_unix_connection
self._socket_path = Path(socket_path)

View File

@ -23,12 +23,12 @@ import logging
from abc import ABC, abstractmethod
from asyncio import AbstractServer
from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter
from asyncio.streams import StreamReader, StreamWriter, start_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Optional, Union
from .client import ControlClient, UnixControlClient
from .client import ControlClient, TCPControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession
from .types import ConnectedCallbackT
@ -132,6 +132,24 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
return create_task(self._serve_forever())
class TCPControlServer(ControlServer):
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
_client_class = TCPControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._host = server_kwargs.pop('host')
self._port = server_kwargs.pop('port')
super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_server(client_connected_cb, self._host, self._port, **kwargs)
log.debug("Opened socket at %s:%s", self._host, self._port)
return server
def _final_callback(self) -> None:
log.debug("Closed socket at %s:%s", self._host, self._port)
class UnixControlServer(ControlServer):
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient