generated from daniil-berg/boilerplate-py
Compare commits
3 Commits
c9e0e2f255
...
ed376b6f82
Author | SHA1 | Date | |
---|---|---|---|
ed376b6f82 | |||
b3b95877fb | |||
6a5c200ae6 |
@ -1 +1,2 @@
|
|||||||
from .pool import TaskPool
|
from .pool import TaskPool
|
||||||
|
from .server import UnixControlServer
|
||||||
|
46
src/asyncio_taskpool/__main__.py
Normal file
46
src/asyncio_taskpool/__main__.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import sys
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from asyncio import run
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from .client import ControlClient, UnixControlClient
|
||||||
|
from .constants import PACKAGE_NAME
|
||||||
|
from .pool import TaskPool
|
||||||
|
from .server import ControlServer
|
||||||
|
|
||||||
|
|
||||||
|
CONN_TYPE = 'conn_type'
|
||||||
|
UNIX, TCP = 'unix', 'tcp'
|
||||||
|
SOCKET_PATH = 'path'
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli() -> Dict[str, Any]:
|
||||||
|
parser = ArgumentParser(
|
||||||
|
prog=PACKAGE_NAME,
|
||||||
|
description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}"
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
|
||||||
|
unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket")
|
||||||
|
unix_parser.add_argument(
|
||||||
|
SOCKET_PATH,
|
||||||
|
type=Path,
|
||||||
|
help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
|
||||||
|
)
|
||||||
|
return vars(parser.parse_args())
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
kwargs = parse_cli()
|
||||||
|
if kwargs[CONN_TYPE] == UNIX:
|
||||||
|
client = UnixControlClient(path=kwargs[SOCKET_PATH])
|
||||||
|
elif kwargs[CONN_TYPE] == TCP:
|
||||||
|
# TODO: Implement the TCP client class
|
||||||
|
client = UnixControlClient(path=kwargs[SOCKET_PATH])
|
||||||
|
else:
|
||||||
|
print("Invalid connection type", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
await client.start()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run(main())
|
63
src/asyncio_taskpool/client.py
Normal file
63
src/asyncio_taskpool/client.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import sys
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from asyncio_taskpool import constants
|
||||||
|
from asyncio_taskpool.types import ClientConnT
|
||||||
|
|
||||||
|
|
||||||
|
class ControlClient(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __init__(self, **conn_kwargs) -> None:
|
||||||
|
self._conn_kwargs = conn_kwargs
|
||||||
|
self._connected: bool = False
|
||||||
|
|
||||||
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
try:
|
||||||
|
msg = input("> ").strip().lower()
|
||||||
|
except EOFError:
|
||||||
|
msg = constants.CLIENT_EXIT
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print()
|
||||||
|
return
|
||||||
|
if msg == constants.CLIENT_EXIT:
|
||||||
|
writer.close()
|
||||||
|
self._connected = False
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
writer.write(msg.encode())
|
||||||
|
await writer.drain()
|
||||||
|
except ConnectionError as e:
|
||||||
|
self._connected = False
|
||||||
|
print(e, file=sys.stderr)
|
||||||
|
return
|
||||||
|
print((await reader.read(constants.MSG_BYTES)).decode())
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
reader, writer = await self.open_connection(**self._conn_kwargs)
|
||||||
|
if reader is None:
|
||||||
|
print("Failed to connect.", file=sys.stderr)
|
||||||
|
return
|
||||||
|
self._connected = True
|
||||||
|
print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
|
||||||
|
while self._connected:
|
||||||
|
await self._interact(reader, writer)
|
||||||
|
print("Disconnected from control server.")
|
||||||
|
|
||||||
|
|
||||||
|
class UnixControlClient(ControlClient):
|
||||||
|
def __init__(self, **conn_kwargs) -> None:
|
||||||
|
self._socket_path = Path(conn_kwargs.pop('path'))
|
||||||
|
super().__init__(**conn_kwargs)
|
||||||
|
|
||||||
|
async def open_connection(self, **kwargs) -> ClientConnT:
|
||||||
|
try:
|
||||||
|
return await open_unix_connection(self._socket_path, **kwargs)
|
||||||
|
except FileNotFoundError:
|
||||||
|
print("No socket at", self._socket_path, file=sys.stderr)
|
||||||
|
return None, None
|
7
src/asyncio_taskpool/constants.py
Normal file
7
src/asyncio_taskpool/constants.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
|
MSG_BYTES = 1024
|
||||||
|
CMD_START = 'start'
|
||||||
|
CMD_STOP = 'stop'
|
||||||
|
CMD_STOP_ALL = 'stop_all'
|
||||||
|
CMD_SIZE = 'size'
|
||||||
|
CLIENT_EXIT = 'exit'
|
@ -19,6 +19,7 @@ class TaskPool:
|
|||||||
self._final_callback: FinalCallbackT = final_callback
|
self._final_callback: FinalCallbackT = final_callback
|
||||||
self._cancel_callback: CancelCallbackT = cancel_callback
|
self._cancel_callback: CancelCallbackT = cancel_callback
|
||||||
self._tasks: List[Task] = []
|
self._tasks: List[Task] = []
|
||||||
|
self._cancelled: List[Task] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def func_name(self) -> str:
|
def func_name(self) -> str:
|
||||||
@ -43,11 +44,20 @@ class TaskPool:
|
|||||||
self._start_one()
|
self._start_one()
|
||||||
|
|
||||||
def stop(self, num: int = 1) -> int:
|
def stop(self, num: int = 1) -> int:
|
||||||
if num < 1:
|
for i in range(num):
|
||||||
return 0
|
try:
|
||||||
return sum(task.cancel() for task in reversed(self._tasks[-num:]))
|
task = self._tasks.pop()
|
||||||
|
except IndexError:
|
||||||
|
num = i
|
||||||
|
break
|
||||||
|
task.cancel()
|
||||||
|
self._cancelled.append(task)
|
||||||
|
return num
|
||||||
|
|
||||||
|
def stop_all(self) -> int:
|
||||||
|
return self.stop(self.size)
|
||||||
|
|
||||||
async def gather(self, return_exceptions: bool = False):
|
async def gather(self, return_exceptions: bool = False):
|
||||||
results = await gather(*self._tasks, return_exceptions=return_exceptions)
|
results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions)
|
||||||
self._tasks = []
|
self._tasks = self._cancelled = []
|
||||||
return results
|
return results
|
||||||
|
128
src/asyncio_taskpool/server.py
Normal file
128
src/asyncio_taskpool/server.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from asyncio import AbstractServer
|
||||||
|
from asyncio.exceptions import CancelledError
|
||||||
|
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
||||||
|
from asyncio.tasks import Task, create_task
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple, Union, Optional
|
||||||
|
|
||||||
|
from . import constants
|
||||||
|
from .pool import TaskPool
|
||||||
|
from .client import ControlClient, UnixControlClient
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def tasks_str(num: int) -> str:
|
||||||
|
return "tasks" if num != 1 else "task"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
|
||||||
|
cmd = msg.strip().split(' ', 1)
|
||||||
|
if len(cmd) > 1:
|
||||||
|
try:
|
||||||
|
return cmd[0], int(cmd[1])
|
||||||
|
except ValueError:
|
||||||
|
return None, None
|
||||||
|
return cmd[0], None
|
||||||
|
|
||||||
|
|
||||||
|
class ControlServer(ABC):
|
||||||
|
client_class = ControlClient
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def final_callback(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __init__(self, pool: TaskPool, **server_kwargs) -> None:
|
||||||
|
self._pool: TaskPool = pool
|
||||||
|
self._server_kwargs = server_kwargs
|
||||||
|
self._server: Optional[AbstractServer] = None
|
||||||
|
|
||||||
|
def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
||||||
|
if num is None:
|
||||||
|
num = 1
|
||||||
|
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
|
||||||
|
self._pool.start(num)
|
||||||
|
size = self._pool.size
|
||||||
|
writer.write(f"{num} new {tasks_str(num)} started! {size} {tasks_str(size)} active now.".encode())
|
||||||
|
|
||||||
|
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
|
||||||
|
if num is None:
|
||||||
|
num = 1
|
||||||
|
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
|
||||||
|
num = self._pool.stop(num) # the requested number may be greater than the total number of running tasks
|
||||||
|
size = self._pool.size
|
||||||
|
writer.write(f"{num} {tasks_str(num)} stopped! {size} {tasks_str(size)} left.".encode())
|
||||||
|
|
||||||
|
def _stop_all_tasks(self, writer: StreamWriter) -> None:
|
||||||
|
log.debug("%s requests stopping all tasks", self.client_class.__name__)
|
||||||
|
num = self._pool.stop_all()
|
||||||
|
writer.write(f"Remaining {num} {tasks_str(num)} stopped!".encode())
|
||||||
|
|
||||||
|
def _pool_size(self, writer: StreamWriter) -> None:
|
||||||
|
log.debug("%s requests pool size", self.client_class.__name__)
|
||||||
|
writer.write(f'{self._pool.size}'.encode())
|
||||||
|
|
||||||
|
async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
while self._server.is_serving():
|
||||||
|
msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
|
||||||
|
if not msg:
|
||||||
|
log.debug("%s disconnected", self.client_class.__name__)
|
||||||
|
break
|
||||||
|
cmd, arg = get_cmd_arg(msg)
|
||||||
|
if cmd == constants.CMD_START:
|
||||||
|
self._start_tasks(writer, arg)
|
||||||
|
elif cmd == constants.CMD_STOP:
|
||||||
|
self._stop_tasks(writer, arg)
|
||||||
|
elif cmd == constants.CMD_STOP_ALL:
|
||||||
|
self._stop_all_tasks(writer)
|
||||||
|
elif cmd == constants.CMD_SIZE:
|
||||||
|
self._pool_size(writer)
|
||||||
|
else:
|
||||||
|
log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
|
||||||
|
writer.write(b"Invalid command!")
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
|
log.debug("%s connected", self.client_class.__name__)
|
||||||
|
writer.write(f"{self.__class__.__name__} for {self._pool}".encode())
|
||||||
|
await writer.drain()
|
||||||
|
await self._listen(reader, writer)
|
||||||
|
|
||||||
|
async def _serve_forever(self) -> None:
|
||||||
|
try:
|
||||||
|
async with self._server:
|
||||||
|
await self._server.serve_forever()
|
||||||
|
except CancelledError:
|
||||||
|
log.debug("%s stopped", self.__class__.__name__)
|
||||||
|
finally:
|
||||||
|
self.final_callback()
|
||||||
|
|
||||||
|
async def serve_forever(self) -> Task:
|
||||||
|
log.debug("Starting %s...", self.__class__.__name__)
|
||||||
|
self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
|
||||||
|
return create_task(self._serve_forever())
|
||||||
|
|
||||||
|
|
||||||
|
class UnixControlServer(ControlServer):
|
||||||
|
client_class = UnixControlClient
|
||||||
|
|
||||||
|
def __init__(self, pool: TaskPool, **server_kwargs) -> None:
|
||||||
|
self._socket_path = Path(server_kwargs.pop('path'))
|
||||||
|
super().__init__(pool, **server_kwargs)
|
||||||
|
|
||||||
|
async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
|
||||||
|
srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||||
|
log.debug("Opened socket '%s'", str(self._socket_path))
|
||||||
|
return srv
|
||||||
|
|
||||||
|
def final_callback(self) -> None:
|
||||||
|
self._socket_path.unlink()
|
||||||
|
log.debug("Removed socket '%s'", str(self._socket_path))
|
@ -1,6 +1,9 @@
|
|||||||
from typing import Callable, Awaitable, Any
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
|
from typing import Tuple, Callable, Awaitable, Union, Any
|
||||||
|
|
||||||
|
|
||||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||||
FinalCallbackT = Callable
|
FinalCallbackT = Callable
|
||||||
CancelCallbackT = Callable
|
CancelCallbackT = Callable
|
||||||
|
|
||||||
|
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]
|
||||||
|
0
usage/__init__.py
Normal file
0
usage/__init__.py
Normal file
64
usage/example_server.py
Normal file
64
usage/example_server.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from asyncio_taskpool import TaskPool, UnixControlServer
|
||||||
|
from asyncio_taskpool.constants import PACKAGE_NAME
|
||||||
|
|
||||||
|
|
||||||
|
logging.getLogger().setLevel(logging.NOTSET)
|
||||||
|
logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
|
||||||
|
|
||||||
|
|
||||||
|
async def work(item: int) -> None:
|
||||||
|
"""The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print("worked on", item)
|
||||||
|
|
||||||
|
|
||||||
|
async def worker(q: asyncio.Queue) -> None:
|
||||||
|
"""Simulates doing asynchronous work that takes a little bit of time to finish."""
|
||||||
|
# We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
|
||||||
|
while True:
|
||||||
|
# We want to block here, until we can get the next item from the queue.
|
||||||
|
item = await q.get()
|
||||||
|
# Since we want a nice cleanup upon cancellation, we put the "work" to be done in a `try:` block.
|
||||||
|
try:
|
||||||
|
await work(item)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# If the task gets cancelled before our current "work" item is finished, we put it back into the queue
|
||||||
|
# because a worker must assume that some other worker can and will eventually finish the work on that item.
|
||||||
|
q.put_nowait(item)
|
||||||
|
# This takes us out of the loop. To enable cleanup we must re-raise the exception.
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Since putting an item into the queue (even if it has just been taken out), increments the internal
|
||||||
|
# `._unfinished_tasks` counter in the queue, we must ensure that it is decremented before we end the
|
||||||
|
# iteration or leave the loop. Otherwise, the queue's `.join()` will block indefinitely.
|
||||||
|
q.task_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
# First, we set up a queue of items that our workers can "work" on.
|
||||||
|
q = asyncio.Queue()
|
||||||
|
# We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit.
|
||||||
|
for item in range(100):
|
||||||
|
q.put_nowait(item)
|
||||||
|
pool = TaskPool(worker, (q,)) # initializes the pool
|
||||||
|
pool.start(3) # launches three worker tasks
|
||||||
|
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
|
||||||
|
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
||||||
|
await q.join()
|
||||||
|
# Since we don't need any "work" done anymore, we can close our control server by cancelling the task.
|
||||||
|
control_server_task.cancel()
|
||||||
|
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
||||||
|
# we can now safely cancel their tasks.
|
||||||
|
pool.stop_all()
|
||||||
|
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
|
||||||
|
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
||||||
|
# we just silently collect their exceptions along with their return values.
|
||||||
|
await pool.gather(return_exceptions=True)
|
||||||
|
await control_server_task
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user