generated from daniil-berg/boilerplate-py
	Compare commits
	
		
			3 Commits
		
	
	
		
			c9e0e2f255
			...
			ed376b6f82
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| ed376b6f82 | |||
| b3b95877fb | |||
| 6a5c200ae6 | 
| @@ -1 +1,2 @@ | |||||||
| from .pool import TaskPool | from .pool import TaskPool | ||||||
|  | from .server import UnixControlServer | ||||||
|   | |||||||
							
								
								
									
										46
									
								
								src/asyncio_taskpool/__main__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								src/asyncio_taskpool/__main__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | import sys | ||||||
|  | from argparse import ArgumentParser | ||||||
|  | from asyncio import run | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Dict, Any | ||||||
|  |  | ||||||
|  | from .client import ControlClient, UnixControlClient | ||||||
|  | from .constants import PACKAGE_NAME | ||||||
|  | from .pool import TaskPool | ||||||
|  | from .server import ControlServer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | CONN_TYPE = 'conn_type' | ||||||
|  | UNIX, TCP = 'unix', 'tcp' | ||||||
|  | SOCKET_PATH = 'path' | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def parse_cli() -> Dict[str, Any]: | ||||||
|  |     parser = ArgumentParser( | ||||||
|  |         prog=PACKAGE_NAME, | ||||||
|  |         description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}" | ||||||
|  |     ) | ||||||
|  |     subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE) | ||||||
|  |     unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket") | ||||||
|  |     unix_parser.add_argument( | ||||||
|  |         SOCKET_PATH, | ||||||
|  |         type=Path, | ||||||
|  |         help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening." | ||||||
|  |     ) | ||||||
|  |     return vars(parser.parse_args()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def main(): | ||||||
|  |     kwargs = parse_cli() | ||||||
|  |     if kwargs[CONN_TYPE] == UNIX: | ||||||
|  |         client = UnixControlClient(path=kwargs[SOCKET_PATH]) | ||||||
|  |     elif kwargs[CONN_TYPE] == TCP: | ||||||
|  |         # TODO: Implement the TCP client class | ||||||
|  |         client = UnixControlClient(path=kwargs[SOCKET_PATH]) | ||||||
|  |     else: | ||||||
|  |         print("Invalid connection type", file=sys.stderr) | ||||||
|  |         sys.exit(2) | ||||||
|  |     await client.start() | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     run(main()) | ||||||
							
								
								
									
										63
									
								
								src/asyncio_taskpool/client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								src/asyncio_taskpool/client.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | |||||||
|  | import sys | ||||||
|  | from abc import ABC, abstractmethod | ||||||
|  | from asyncio.streams import StreamReader, StreamWriter, open_unix_connection | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | from asyncio_taskpool import constants | ||||||
|  | from asyncio_taskpool.types import ClientConnT | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ControlClient(ABC): | ||||||
|  |  | ||||||
|  |     @abstractmethod | ||||||
|  |     async def open_connection(self, **kwargs) -> ClientConnT: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __init__(self, **conn_kwargs) -> None: | ||||||
|  |         self._conn_kwargs = conn_kwargs | ||||||
|  |         self._connected: bool = False | ||||||
|  |  | ||||||
|  |     async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: | ||||||
|  |         try: | ||||||
|  |             msg = input("> ").strip().lower() | ||||||
|  |         except EOFError: | ||||||
|  |             msg = constants.CLIENT_EXIT | ||||||
|  |         except KeyboardInterrupt: | ||||||
|  |             print() | ||||||
|  |             return | ||||||
|  |         if msg == constants.CLIENT_EXIT: | ||||||
|  |             writer.close() | ||||||
|  |             self._connected = False | ||||||
|  |             return | ||||||
|  |         try: | ||||||
|  |             writer.write(msg.encode()) | ||||||
|  |             await writer.drain() | ||||||
|  |         except ConnectionError as e: | ||||||
|  |             self._connected = False | ||||||
|  |             print(e, file=sys.stderr) | ||||||
|  |             return | ||||||
|  |         print((await reader.read(constants.MSG_BYTES)).decode()) | ||||||
|  |  | ||||||
|  |     async def start(self): | ||||||
|  |         reader, writer = await self.open_connection(**self._conn_kwargs) | ||||||
|  |         if reader is None: | ||||||
|  |             print("Failed to connect.", file=sys.stderr) | ||||||
|  |             return | ||||||
|  |         self._connected = True | ||||||
|  |         print("Connected to", (await reader.read(constants.MSG_BYTES)).decode()) | ||||||
|  |         while self._connected: | ||||||
|  |             await self._interact(reader, writer) | ||||||
|  |         print("Disconnected from control server.") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UnixControlClient(ControlClient): | ||||||
|  |     def __init__(self, **conn_kwargs) -> None: | ||||||
|  |         self._socket_path = Path(conn_kwargs.pop('path')) | ||||||
|  |         super().__init__(**conn_kwargs) | ||||||
|  |  | ||||||
|  |     async def open_connection(self, **kwargs) -> ClientConnT: | ||||||
|  |         try: | ||||||
|  |             return await open_unix_connection(self._socket_path, **kwargs) | ||||||
|  |         except FileNotFoundError: | ||||||
|  |             print("No socket at", self._socket_path, file=sys.stderr) | ||||||
|  |             return None, None | ||||||
							
								
								
									
										7
									
								
								src/asyncio_taskpool/constants.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								src/asyncio_taskpool/constants.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | PACKAGE_NAME = 'asyncio_taskpool' | ||||||
|  | MSG_BYTES = 1024 | ||||||
|  | CMD_START = 'start' | ||||||
|  | CMD_STOP = 'stop' | ||||||
|  | CMD_STOP_ALL = 'stop_all' | ||||||
|  | CMD_SIZE = 'size' | ||||||
|  | CLIENT_EXIT = 'exit' | ||||||
| @@ -19,6 +19,7 @@ class TaskPool: | |||||||
|         self._final_callback: FinalCallbackT = final_callback |         self._final_callback: FinalCallbackT = final_callback | ||||||
|         self._cancel_callback: CancelCallbackT = cancel_callback |         self._cancel_callback: CancelCallbackT = cancel_callback | ||||||
|         self._tasks: List[Task] = [] |         self._tasks: List[Task] = [] | ||||||
|  |         self._cancelled: List[Task] = [] | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def func_name(self) -> str: |     def func_name(self) -> str: | ||||||
| @@ -43,11 +44,20 @@ class TaskPool: | |||||||
|             self._start_one() |             self._start_one() | ||||||
|  |  | ||||||
|     def stop(self, num: int = 1) -> int: |     def stop(self, num: int = 1) -> int: | ||||||
|         if num < 1: |         for i in range(num): | ||||||
|             return 0 |             try: | ||||||
|         return sum(task.cancel() for task in reversed(self._tasks[-num:])) |                 task = self._tasks.pop() | ||||||
|  |             except IndexError: | ||||||
|  |                 num = i | ||||||
|  |                 break | ||||||
|  |             task.cancel() | ||||||
|  |             self._cancelled.append(task) | ||||||
|  |         return num | ||||||
|  |  | ||||||
|  |     def stop_all(self) -> int: | ||||||
|  |         return self.stop(self.size) | ||||||
|  |  | ||||||
|     async def gather(self, return_exceptions: bool = False): |     async def gather(self, return_exceptions: bool = False): | ||||||
|         results = await gather(*self._tasks, return_exceptions=return_exceptions) |         results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions) | ||||||
|         self._tasks = [] |         self._tasks = self._cancelled = [] | ||||||
|         return results |         return results | ||||||
|   | |||||||
							
								
								
									
										128
									
								
								src/asyncio_taskpool/server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								src/asyncio_taskpool/server.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,128 @@ | |||||||
|  | import logging | ||||||
|  | from abc import ABC, abstractmethod | ||||||
|  | from asyncio import AbstractServer | ||||||
|  | from asyncio.exceptions import CancelledError | ||||||
|  | from asyncio.streams import StreamReader, StreamWriter, start_unix_server | ||||||
|  | from asyncio.tasks import Task, create_task | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Tuple, Union, Optional | ||||||
|  |  | ||||||
|  | from . import constants | ||||||
|  | from .pool import TaskPool | ||||||
|  | from .client import ControlClient, UnixControlClient | ||||||
|  |  | ||||||
|  |  | ||||||
|  | log = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def tasks_str(num: int) -> str: | ||||||
|  |     return "tasks" if num != 1 else "task" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]: | ||||||
|  |     cmd = msg.strip().split(' ', 1) | ||||||
|  |     if len(cmd) > 1: | ||||||
|  |         try: | ||||||
|  |             return cmd[0], int(cmd[1]) | ||||||
|  |         except ValueError: | ||||||
|  |             return None, None | ||||||
|  |     return cmd[0], None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ControlServer(ABC): | ||||||
|  |     client_class = ControlClient | ||||||
|  |  | ||||||
|  |     @abstractmethod | ||||||
|  |     async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @abstractmethod | ||||||
|  |     def final_callback(self) -> None: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __init__(self, pool: TaskPool, **server_kwargs) -> None: | ||||||
|  |         self._pool: TaskPool = pool | ||||||
|  |         self._server_kwargs = server_kwargs | ||||||
|  |         self._server: Optional[AbstractServer] = None | ||||||
|  |  | ||||||
|  |     def _start_tasks(self, writer: StreamWriter, num: int = None) -> None: | ||||||
|  |         if num is None: | ||||||
|  |             num = 1 | ||||||
|  |         log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num)) | ||||||
|  |         self._pool.start(num) | ||||||
|  |         size = self._pool.size | ||||||
|  |         writer.write(f"{num} new {tasks_str(num)} started! {size} {tasks_str(size)} active now.".encode()) | ||||||
|  |  | ||||||
|  |     def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: | ||||||
|  |         if num is None: | ||||||
|  |             num = 1 | ||||||
|  |         log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num)) | ||||||
|  |         num = self._pool.stop(num)  # the requested number may be greater than the total number of running tasks | ||||||
|  |         size = self._pool.size | ||||||
|  |         writer.write(f"{num} {tasks_str(num)} stopped! {size} {tasks_str(size)} left.".encode()) | ||||||
|  |  | ||||||
|  |     def _stop_all_tasks(self, writer: StreamWriter) -> None: | ||||||
|  |         log.debug("%s requests stopping all tasks", self.client_class.__name__) | ||||||
|  |         num = self._pool.stop_all() | ||||||
|  |         writer.write(f"Remaining {num} {tasks_str(num)} stopped!".encode()) | ||||||
|  |  | ||||||
|  |     def _pool_size(self, writer: StreamWriter) -> None: | ||||||
|  |         log.debug("%s requests pool size", self.client_class.__name__) | ||||||
|  |         writer.write(f'{self._pool.size}'.encode()) | ||||||
|  |  | ||||||
|  |     async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None: | ||||||
|  |         while self._server.is_serving(): | ||||||
|  |             msg = (await reader.read(constants.MSG_BYTES)).decode().strip() | ||||||
|  |             if not msg: | ||||||
|  |                 log.debug("%s disconnected", self.client_class.__name__) | ||||||
|  |                 break | ||||||
|  |             cmd, arg = get_cmd_arg(msg) | ||||||
|  |             if cmd == constants.CMD_START: | ||||||
|  |                 self._start_tasks(writer, arg) | ||||||
|  |             elif cmd == constants.CMD_STOP: | ||||||
|  |                 self._stop_tasks(writer, arg) | ||||||
|  |             elif cmd == constants.CMD_STOP_ALL: | ||||||
|  |                 self._stop_all_tasks(writer) | ||||||
|  |             elif cmd == constants.CMD_SIZE: | ||||||
|  |                 self._pool_size(writer) | ||||||
|  |             else: | ||||||
|  |                 log.debug("%s sent invalid command: %s", self.client_class.__name__, msg) | ||||||
|  |                 writer.write(b"Invalid command!") | ||||||
|  |             await writer.drain() | ||||||
|  |  | ||||||
|  |     async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: | ||||||
|  |         log.debug("%s connected", self.client_class.__name__) | ||||||
|  |         writer.write(f"{self.__class__.__name__} for {self._pool}".encode()) | ||||||
|  |         await writer.drain() | ||||||
|  |         await self._listen(reader, writer) | ||||||
|  |  | ||||||
|  |     async def _serve_forever(self) -> None: | ||||||
|  |         try: | ||||||
|  |             async with self._server: | ||||||
|  |                 await self._server.serve_forever() | ||||||
|  |         except CancelledError: | ||||||
|  |             log.debug("%s stopped", self.__class__.__name__) | ||||||
|  |         finally: | ||||||
|  |             self.final_callback() | ||||||
|  |  | ||||||
|  |     async def serve_forever(self) -> Task: | ||||||
|  |         log.debug("Starting %s...", self.__class__.__name__) | ||||||
|  |         self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs) | ||||||
|  |         return create_task(self._serve_forever()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UnixControlServer(ControlServer): | ||||||
|  |     client_class = UnixControlClient | ||||||
|  |  | ||||||
|  |     def __init__(self, pool: TaskPool, **server_kwargs) -> None: | ||||||
|  |         self._socket_path = Path(server_kwargs.pop('path')) | ||||||
|  |         super().__init__(pool, **server_kwargs) | ||||||
|  |  | ||||||
|  |     async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer: | ||||||
|  |         srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) | ||||||
|  |         log.debug("Opened socket '%s'", str(self._socket_path)) | ||||||
|  |         return srv | ||||||
|  |  | ||||||
|  |     def final_callback(self) -> None: | ||||||
|  |         self._socket_path.unlink() | ||||||
|  |         log.debug("Removed socket '%s'", str(self._socket_path)) | ||||||
| @@ -1,6 +1,9 @@ | |||||||
| from typing import Callable, Awaitable, Any | from asyncio.streams import StreamReader, StreamWriter | ||||||
|  | from typing import Tuple, Callable, Awaitable, Union, Any | ||||||
|  |  | ||||||
|  |  | ||||||
| CoroutineFunc = Callable[[...], Awaitable[Any]] | CoroutineFunc = Callable[[...], Awaitable[Any]] | ||||||
| FinalCallbackT = Callable | FinalCallbackT = Callable | ||||||
| CancelCallbackT = Callable | CancelCallbackT = Callable | ||||||
|  |  | ||||||
|  | ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] | ||||||
|   | |||||||
							
								
								
									
										0
									
								
								usage/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								usage/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										64
									
								
								usage/example_server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								usage/example_server.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | |||||||
|  | import asyncio | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  | from asyncio_taskpool import TaskPool, UnixControlServer | ||||||
|  | from asyncio_taskpool.constants import PACKAGE_NAME | ||||||
|  |  | ||||||
|  |  | ||||||
|  | logging.getLogger().setLevel(logging.NOTSET) | ||||||
|  | logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def work(item: int) -> None: | ||||||
|  |     """The non-blocking sleep simulates something like an I/O operation that can be done asynchronously.""" | ||||||
|  |     await asyncio.sleep(1) | ||||||
|  |     print("worked on", item) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def worker(q: asyncio.Queue) -> None: | ||||||
|  |     """Simulates doing asynchronous work that takes a little bit of time to finish.""" | ||||||
|  |     # We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop. | ||||||
|  |     while True: | ||||||
|  |         # We want to block here, until we can get the next item from the queue. | ||||||
|  |         item = await q.get() | ||||||
|  |         # Since we want a nice cleanup upon cancellation, we put the "work" to be done in a `try:` block. | ||||||
|  |         try: | ||||||
|  |             await work(item) | ||||||
|  |         except asyncio.CancelledError: | ||||||
|  |             # If the task gets cancelled before our current "work" item is finished, we put it back into the queue | ||||||
|  |             # because a worker must assume that some other worker can and will eventually finish the work on that item. | ||||||
|  |             q.put_nowait(item) | ||||||
|  |             # This takes us out of the loop. To enable cleanup we must re-raise the exception. | ||||||
|  |             raise | ||||||
|  |         finally: | ||||||
|  |             # Since putting an item into the queue (even if it has just been taken out), increments the internal | ||||||
|  |             # `._unfinished_tasks` counter in the queue, we must ensure that it is decremented before we end the | ||||||
|  |             # iteration or leave the loop. Otherwise, the queue's `.join()` will block indefinitely. | ||||||
|  |             q.task_done() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def main() -> None: | ||||||
|  |     # First, we set up a queue of items that our workers can "work" on. | ||||||
|  |     q = asyncio.Queue() | ||||||
|  |     # We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit. | ||||||
|  |     for item in range(100): | ||||||
|  |         q.put_nowait(item) | ||||||
|  |     pool = TaskPool(worker, (q,))  # initializes the pool | ||||||
|  |     pool.start(3)  # launches three worker tasks | ||||||
|  |     control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() | ||||||
|  |     # We block until `.task_done()` has been called once by our workers for every item placed into the queue. | ||||||
|  |     await q.join() | ||||||
|  |     # Since we don't need any "work" done anymore, we can close our control server by cancelling the task. | ||||||
|  |     control_server_task.cancel() | ||||||
|  |     # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, | ||||||
|  |     # we can now safely cancel their tasks. | ||||||
|  |     pool.stop_all() | ||||||
|  |     # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. | ||||||
|  |     # We block until they all return or raise an exception, but since we are not interested in any of their exceptions, | ||||||
|  |     # we just silently collect their exceptions along with their return values. | ||||||
|  |     await pool.gather(return_exceptions=True) | ||||||
|  |     await control_server_task | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     asyncio.run(main()) | ||||||
		Reference in New Issue
	
	Block a user