diff --git a/src/asyncio_taskpool/__init__.py b/src/asyncio_taskpool/__init__.py index d28d142..c3b03b7 100644 --- a/src/asyncio_taskpool/__init__.py +++ b/src/asyncio_taskpool/__init__.py @@ -20,4 +20,4 @@ Brings the main classes up to package level for import convenience. from .pool import TaskPool, SimpleTaskPool -from .server import UnixControlServer +from .server import TCPControlServer, UnixControlServer diff --git a/src/asyncio_taskpool/__main__.py b/src/asyncio_taskpool/__main__.py index 9e8c59b..5c9e97e 100644 --- a/src/asyncio_taskpool/__main__.py +++ b/src/asyncio_taskpool/__main__.py @@ -25,15 +25,16 @@ from asyncio import run from pathlib import Path from typing import Dict, Any -from .client import ControlClient, UnixControlClient +from .client import ControlClient, TCPControlClient, UnixControlClient from .constants import PACKAGE_NAME from .pool import TaskPool -from .server import ControlServer +from .server import TCPControlServer, UnixControlServer CONN_TYPE = 'conn_type' UNIX, TCP = 'unix', 'tcp' SOCKET_PATH = 'path' +HOST, PORT = 'host', 'port' def parse_cli() -> Dict[str, Any]: @@ -46,7 +47,18 @@ def parse_cli() -> Dict[str, Any]: unix_parser.add_argument( SOCKET_PATH, type=Path, - help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening." + help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is " + f"listening." + ) + tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket") + tcp_parser.add_argument( + HOST, + help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on." + ) + tcp_parser.add_argument( + PORT, + type=int, + help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on." ) return vars(parser.parse_args()) @@ -56,8 +68,7 @@ async def main(): if kwargs[CONN_TYPE] == UNIX: client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) elif kwargs[CONN_TYPE] == TCP: - # TODO: Implement the TCP client class - client = UnixControlClient(socket_path=kwargs[SOCKET_PATH]) + client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT]) else: print("Invalid connection type", file=sys.stderr) sys.exit(2) diff --git a/src/asyncio_taskpool/client.py b/src/asyncio_taskpool/client.py index eab4c37..9a08127 100644 --- a/src/asyncio_taskpool/client.py +++ b/src/asyncio_taskpool/client.py @@ -23,9 +23,9 @@ import json import shutil import sys from abc import ABC, abstractmethod -from asyncio.streams import StreamReader, StreamWriter +from asyncio.streams import StreamReader, StreamWriter, open_connection from pathlib import Path -from typing import Optional +from typing import Optional, Union from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES from .types import ClientConnT, PathT @@ -50,8 +50,8 @@ class ControlClient(ABC): """ Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair. - This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked) - as keyword-arguments. + This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` + (unpacked) as keyword-arguments. This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of `None` and `None`, if it failed to establish the defined connection. """ @@ -144,15 +144,34 @@ class ControlClient(ABC): print("Disconnected from control server.") +class TCPControlClient(ControlClient): + """Task pool control client that expects a TCP socket to be exposed by the control server.""" + + def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None: + """In addition to what the base class does, `host` and `port` are expected as non-optional arguments.""" + self._host = host + self._port = port + super().__init__(**conn_kwargs) + + async def _open_connection(self, **kwargs) -> ClientConnT: + """ + Wrapper around the `asyncio.open_connection` function. + + Returns a tuple of `None` and `None`, if the connection can not be established; + otherwise, the stream-reader and -writer tuple is returned. + """ + try: + return await open_connection(self._host, self._port, **kwargs) + except ConnectionError as e: + print(str(e), file=sys.stderr) + return None, None + + class UnixControlClient(ControlClient): """Task pool control client that expects a unix socket to be exposed by the control server.""" def __init__(self, socket_path: PathT, **conn_kwargs) -> None: - """ - In addition to what the base class does, the `socket_path` is expected as a non-optional argument. - - The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument. - """ + """In addition to what the base class does, the `socket_path` is expected as a non-optional argument.""" from asyncio.streams import open_unix_connection self._open_unix_connection = open_unix_connection self._socket_path = Path(socket_path) diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index d7a0f4c..7b020f5 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -23,12 +23,12 @@ import logging from abc import ABC, abstractmethod from asyncio import AbstractServer from asyncio.exceptions import CancelledError -from asyncio.streams import StreamReader, StreamWriter +from asyncio.streams import StreamReader, StreamWriter, start_server from asyncio.tasks import Task, create_task from pathlib import Path from typing import Optional, Union -from .client import ControlClient, UnixControlClient +from .client import ControlClient, TCPControlClient, UnixControlClient from .pool import TaskPool, SimpleTaskPool from .session import ControlSession from .types import ConnectedCallbackT @@ -132,6 +132,24 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta return create_task(self._serve_forever()) +class TCPControlServer(ControlServer): + """Task pool control server class that exposes a TCP socket for control clients to connect to.""" + _client_class = TCPControlClient + + def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: + self._host = server_kwargs.pop('host') + self._port = server_kwargs.pop('port') + super().__init__(pool, **server_kwargs) + + async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: + server = await start_server(client_connected_cb, self._host, self._port, **kwargs) + log.debug("Opened socket at %s:%s", self._host, self._port) + return server + + def _final_callback(self) -> None: + log.debug("Closed socket at %s:%s", self._host, self._port) + + class UnixControlServer(ControlServer): """Task pool control server class that exposes a unix socket for control clients to connect to.""" _client_class = UnixControlClient diff --git a/usage/example_server.py b/usage/example_server.py index f4a30e0..6d819c4 100644 --- a/usage/example_server.py +++ b/usage/example_server.py @@ -23,7 +23,7 @@ Use the main CLI client to interface at the socket. import asyncio import logging -from asyncio_taskpool import SimpleTaskPool, UnixControlServer +from asyncio_taskpool import SimpleTaskPool, TCPControlServer from asyncio_taskpool.constants import PACKAGE_NAME @@ -34,11 +34,11 @@ logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler()) async def work(item: int) -> None: """The non-blocking sleep simulates something like an I/O operation that can be done asynchronously.""" await asyncio.sleep(1) - print("worked on", item) + print("worked on", item, flush=True) async def worker(q: asyncio.Queue) -> None: - """Simulates doing asynchronous work that takes a little bit of time to finish.""" + """Simulates doing asynchronous work that takes a bit of time to finish.""" # We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop. while True: # We want to block here, until we can get the next item from the queue. @@ -67,7 +67,7 @@ async def main() -> None: q.put_nowait(item) pool = SimpleTaskPool(worker, (q,)) # initializes the pool await pool.start(3) # launches three worker tasks - control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() + control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever() # We block until `.task_done()` has been called once by our workers for every item placed into the queue. await q.join() # Since we don't need any "work" done anymore, we can lock our control server by cancelling the task. @@ -76,7 +76,7 @@ async def main() -> None: # we can now safely cancel their tasks. pool.lock() pool.stop_all() - # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. + # Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled. # We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # we just silently collect their exceptions along with their return values. await pool.gather_and_close(return_exceptions=True)