implemented TCP socket control; switched example to TCP

This commit is contained in:
Daniil Fajnberg 2022-02-25 22:42:37 +01:00
parent ae6bb1bd17
commit 5dad4ab0c7
5 changed files with 70 additions and 22 deletions

View File

@ -20,4 +20,4 @@ Brings the main classes up to package level for import convenience.
from .pool import TaskPool, SimpleTaskPool from .pool import TaskPool, SimpleTaskPool
from .server import UnixControlServer from .server import TCPControlServer, UnixControlServer

View File

@ -25,15 +25,16 @@ from asyncio import run
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Dict, Any
from .client import ControlClient, UnixControlClient from .client import ControlClient, TCPControlClient, UnixControlClient
from .constants import PACKAGE_NAME from .constants import PACKAGE_NAME
from .pool import TaskPool from .pool import TaskPool
from .server import ControlServer from .server import TCPControlServer, UnixControlServer
CONN_TYPE = 'conn_type' CONN_TYPE = 'conn_type'
UNIX, TCP = 'unix', 'tcp' UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path' SOCKET_PATH = 'path'
HOST, PORT = 'host', 'port'
def parse_cli() -> Dict[str, Any]: def parse_cli() -> Dict[str, Any]:
@ -46,7 +47,18 @@ def parse_cli() -> Dict[str, Any]:
unix_parser.add_argument( unix_parser.add_argument(
SOCKET_PATH, SOCKET_PATH,
type=Path, type=Path,
help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening." help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
f"listening."
)
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
tcp_parser.add_argument(
HOST,
help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
)
tcp_parser.add_argument(
PORT,
type=int,
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
) )
return vars(parser.parse_args()) return vars(parser.parse_args())
@ -56,8 +68,7 @@ async def main():
if kwargs[CONN_TYPE] == UNIX: if kwargs[CONN_TYPE] == UNIX:
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP: elif kwargs[CONN_TYPE] == TCP:
# TODO: Implement the TCP client class client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
else: else:
print("Invalid connection type", file=sys.stderr) print("Invalid connection type", file=sys.stderr)
sys.exit(2) sys.exit(2)

View File

@ -23,9 +23,9 @@ import json
import shutil import shutil
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter, open_connection
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Union
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
from .types import ClientConnT, PathT from .types import ClientConnT, PathT
@ -50,8 +50,8 @@ class ControlClient(ABC):
""" """
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair. Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked) This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs`
as keyword-arguments. (unpacked) as keyword-arguments.
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
`None` and `None`, if it failed to establish the defined connection. `None` and `None`, if it failed to establish the defined connection.
""" """
@ -144,15 +144,34 @@ class ControlClient(ABC):
print("Disconnected from control server.") print("Disconnected from control server.")
class TCPControlClient(ControlClient):
"""Task pool control client that expects a TCP socket to be exposed by the control server."""
def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None:
"""In addition to what the base class does, `host` and `port` are expected as non-optional arguments."""
self._host = host
self._port = port
super().__init__(**conn_kwargs)
async def _open_connection(self, **kwargs) -> ClientConnT:
"""
Wrapper around the `asyncio.open_connection` function.
Returns a tuple of `None` and `None`, if the connection can not be established;
otherwise, the stream-reader and -writer tuple is returned.
"""
try:
return await open_connection(self._host, self._port, **kwargs)
except ConnectionError as e:
print(str(e), file=sys.stderr)
return None, None
class UnixControlClient(ControlClient): class UnixControlClient(ControlClient):
"""Task pool control client that expects a unix socket to be exposed by the control server.""" """Task pool control client that expects a unix socket to be exposed by the control server."""
def __init__(self, socket_path: PathT, **conn_kwargs) -> None: def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
""" """In addition to what the base class does, the `socket_path` is expected as a non-optional argument."""
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
"""
from asyncio.streams import open_unix_connection from asyncio.streams import open_unix_connection
self._open_unix_connection = open_unix_connection self._open_unix_connection = open_unix_connection
self._socket_path = Path(socket_path) self._socket_path = Path(socket_path)

View File

@ -23,12 +23,12 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import AbstractServer from asyncio import AbstractServer
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter, start_server
from asyncio.tasks import Task, create_task from asyncio.tasks import Task, create_task
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from .client import ControlClient, UnixControlClient from .client import ControlClient, TCPControlClient, UnixControlClient
from .pool import TaskPool, SimpleTaskPool from .pool import TaskPool, SimpleTaskPool
from .session import ControlSession from .session import ControlSession
from .types import ConnectedCallbackT from .types import ConnectedCallbackT
@ -132,6 +132,24 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
return create_task(self._serve_forever()) return create_task(self._serve_forever())
class TCPControlServer(ControlServer):
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
_client_class = TCPControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._host = server_kwargs.pop('host')
self._port = server_kwargs.pop('port')
super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_server(client_connected_cb, self._host, self._port, **kwargs)
log.debug("Opened socket at %s:%s", self._host, self._port)
return server
def _final_callback(self) -> None:
log.debug("Closed socket at %s:%s", self._host, self._port)
class UnixControlServer(ControlServer): class UnixControlServer(ControlServer):
"""Task pool control server class that exposes a unix socket for control clients to connect to.""" """Task pool control server class that exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient _client_class = UnixControlClient

View File

@ -23,7 +23,7 @@ Use the main CLI client to interface at the socket.
import asyncio import asyncio
import logging import logging
from asyncio_taskpool import SimpleTaskPool, UnixControlServer from asyncio_taskpool import SimpleTaskPool, TCPControlServer
from asyncio_taskpool.constants import PACKAGE_NAME from asyncio_taskpool.constants import PACKAGE_NAME
@ -34,11 +34,11 @@ logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
async def work(item: int) -> None: async def work(item: int) -> None:
"""The non-blocking sleep simulates something like an I/O operation that can be done asynchronously.""" """The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
await asyncio.sleep(1) await asyncio.sleep(1)
print("worked on", item) print("worked on", item, flush=True)
async def worker(q: asyncio.Queue) -> None: async def worker(q: asyncio.Queue) -> None:
"""Simulates doing asynchronous work that takes a little bit of time to finish.""" """Simulates doing asynchronous work that takes a bit of time to finish."""
# We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop. # We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
while True: while True:
# We want to block here, until we can get the next item from the queue. # We want to block here, until we can get the next item from the queue.
@ -67,7 +67,7 @@ async def main() -> None:
q.put_nowait(item) q.put_nowait(item)
pool = SimpleTaskPool(worker, (q,)) # initializes the pool pool = SimpleTaskPool(worker, (q,)) # initializes the pool
await pool.start(3) # launches three worker tasks await pool.start(3) # launches three worker tasks
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue. # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join() await q.join()
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task. # Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
@ -76,7 +76,7 @@ async def main() -> None:
# we can now safely cancel their tasks. # we can now safely cancel their tasks.
pool.lock() pool.lock()
pool.stop_all() pool.stop_all()
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. # Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
# we just silently collect their exceptions along with their return values. # we just silently collect their exceptions along with their return values.
await pool.gather_and_close(return_exceptions=True) await pool.gather_and_close(return_exceptions=True)