From 3eae7d803f9e9e5daa02c018bf5a41aca53b12ad Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sat, 5 Feb 2022 18:02:32 +0100 Subject: [PATCH] huge rework; two different task pool classes now --- src/asyncio_taskpool/__init__.py | 2 +- src/asyncio_taskpool/exceptions.py | 26 +++ src/asyncio_taskpool/pool.py | 255 ++++++++++++++++++++++++----- src/asyncio_taskpool/server.py | 26 ++- src/asyncio_taskpool/task.py | 30 ---- src/asyncio_taskpool/types.py | 4 +- usage/USAGE.md | 42 ++--- usage/example_server.py | 7 +- 8 files changed, 277 insertions(+), 115 deletions(-) create mode 100644 src/asyncio_taskpool/exceptions.py delete mode 100644 src/asyncio_taskpool/task.py diff --git a/src/asyncio_taskpool/__init__.py b/src/asyncio_taskpool/__init__.py index d2f82a0..5cef20b 100644 --- a/src/asyncio_taskpool/__init__.py +++ b/src/asyncio_taskpool/__init__.py @@ -1,2 +1,2 @@ -from .pool import TaskPool +from .pool import TaskPool, SimpleTaskPool from .server import UnixControlServer diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py new file mode 100644 index 0000000..9171b26 --- /dev/null +++ b/src/asyncio_taskpool/exceptions.py @@ -0,0 +1,26 @@ +class PoolException(Exception): + pass + + +class PoolIsClosed(PoolException): + pass + + +class TaskEnded(PoolException): + pass + + +class AlreadyCancelled(TaskEnded): + pass + + +class AlreadyFinished(TaskEnded): + pass + + +class InvalidTaskID(PoolException): + pass + + +class PoolStillOpen(PoolException): + pass diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index f26bfc8..6e47c63 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -1,36 +1,207 @@ import logging from asyncio import gather -from asyncio.tasks import Task -from typing import Mapping, List, Iterable, Any +from asyncio.coroutines import iscoroutinefunction +from asyncio.exceptions import CancelledError +from asyncio.locks import Event +from asyncio.tasks import Task, create_task +from math import inf +from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple -from .types import CoroutineFunc, FinalCallbackT, CancelCallbackT -from .task import start_task +from . import exceptions +from .types import ArgsT, KwArgsT, CoroutineFunc, FinalCallbackT, CancelCallbackT log = logging.getLogger(__name__) -class TaskPool: - _pools: List['TaskPool'] = [] +class BaseTaskPool: + _pools: List['BaseTaskPool'] = [] @classmethod - def _add_pool(cls, pool: 'TaskPool') -> int: + def _add_pool(cls, pool: 'BaseTaskPool') -> int: cls._pools.append(pool) return len(cls._pools) - 1 - def __init__(self, func: CoroutineFunc, args: Iterable[Any] = (), kwargs: Mapping[str, Any] = None, + def __init__(self, max_size: int = inf, name: str = None) -> None: + self._max_size: int = max_size # TODO: Make use of a synchronization primitive for this to work + self._open: bool = True + self._counter: int = 0 + self._running: Dict[int, Task] = {} + self._cancelled: Dict[int, Task] = {} + self._ended: Dict[int, Task] = {} + self._all_tasks_known: Event = Event() + self._all_tasks_known.set() + self._idx: int = self._add_pool(self) + self._name: str = name + log.debug("%s initialized", str(self)) + + def __str__(self) -> str: + return f'{self.__class__.__name__}-{self._name or self._idx}' + + @property + def num_running(self) -> int: + return len(self._running) + + @property + def num_cancelled(self) -> int: + return len(self._cancelled) + + @property + def num_ended(self) -> int: + return len(self._ended) + + @property + def num_finished(self) -> int: + return self.num_ended - self.num_cancelled + + def _task_name(self, task_id: int) -> str: + return f'{self}_Task-{task_id}' + + async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None: + log.debug("Cancelling %s ...", self._task_name(task_id)) + task = self._running.pop(task_id) + assert task is not None + self._cancelled[task_id] = task + await _execute_function(custom_callback, args=(task_id, )) + log.debug("Cancelled %s", self._task_name(task_id)) + + async def _end_task(self, task_id: int, custom_callback: FinalCallbackT = None) -> None: + task = self._running.pop(task_id, None) + if task is None: + task = self._cancelled[task_id] + self._ended[task_id] = task + await _execute_function(custom_callback, args=(task_id, )) + log.info("Ended %s", self._task_name(task_id)) + + async def _task_wrapper(self, awaitable: Awaitable, task_id: int, final_callback: FinalCallbackT = None, + cancel_callback: CancelCallbackT = None) -> Any: + log.info("Started %s", self._task_name(task_id)) + try: + return await awaitable + except CancelledError: + await self._cancel_task(task_id, custom_callback=cancel_callback) + finally: + await self._end_task(task_id, custom_callback=final_callback) + + def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None, + cancel_callback: CancelCallbackT = None) -> int: + if not (self._open or ignore_closed): + raise exceptions.PoolIsClosed("Cannot start new tasks") + task_id = self._counter + self._counter += 1 + self._running[task_id] = create_task( + self._task_wrapper(awaitable, task_id, final_callback, cancel_callback), + name=self._task_name(task_id) + ) + return task_id + + def _cancel_one(self, task_id: int, msg: str = None) -> None: + try: + task = self._running[task_id] + except KeyError: + if self._cancelled.get(task_id): + raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled") + if self._ended.get(task_id): + raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running") + raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") + task.cancel(msg=msg) + + def cancel(self, *task_ids: int, msg: str = None) -> None: + for task_id in task_ids: + self._cancel_one(task_id, msg=msg) + + def cancel_all(self, msg: str = None) -> None: + for task in self._running.values(): + task.cancel(msg=msg) + + def close(self) -> None: + self._open = False + log.info("%s is closed!", str(self)) + + async def gather(self, return_exceptions: bool = False): + if self._open: + raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered") + await self._all_tasks_known.wait() + results = await gather(*self._running.values(), *self._ended.values(), return_exceptions=return_exceptions) + self._running = self._cancelled = self._ended = {} + return results + + +class TaskPool(BaseTaskPool): + def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int: + if kwargs is None: + kwargs = {} + return self._start_task(func(*args, **kwargs), final_callback=final_callback, cancel_callback=cancel_callback) + + def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: + return tuple(self._apply_one(func, args, kwargs, final_callback, cancel_callback) for _ in range(num)) + + @staticmethod + def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]: + try: + arg = next(args_iter) + except StopIteration: + return + if arg_stars == 0: + return func(arg) + if arg_stars == 1: + return func(*arg) + if arg_stars == 2: + return func(**arg) + raise ValueError + + def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + + if self._all_tasks_known.is_set(): + self._all_tasks_known.clear() + args_iter = iter(args_iter) + + def _start_next_coroutine() -> bool: + cor = self._get_next_coroutine(func, args_iter, arg_stars) + if cor is None: + self._all_tasks_known.set() + return True + self._start_task(cor, ignore_closed=True, final_callback=_start_next, cancel_callback=cancel_callback) + return False + + async def _start_next(task_id: int) -> None: + await _execute_function(final_callback, args=(task_id, )) + _start_next_coroutine() + + for _ in range(num_tasks): + reached_end = _start_next_coroutine() + if reached_end: + break + + def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks, + final_callback=final_callback, cancel_callback=cancel_callback) + + def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1, + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks, + final_callback=final_callback, cancel_callback=cancel_callback) + + def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1, + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, + final_callback=final_callback, cancel_callback=cancel_callback) + + +class SimpleTaskPool(BaseTaskPool): + def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None, name: str = None) -> None: self._func: CoroutineFunc = func - self._args: Iterable[Any] = args - self._kwargs: Mapping[str, Any] = kwargs if kwargs is not None else {} + self._args: ArgsT = args + self._kwargs: KwArgsT = kwargs if kwargs is not None else {} self._final_callback: FinalCallbackT = final_callback self._cancel_callback: CancelCallbackT = cancel_callback - self._tasks: List[Task] = [] - self._cancelled: List[Task] = [] - self._idx: int = self._add_pool(self) - self._name: str = name - log.debug("%s initialized", repr(self)) + super().__init__(name=name) @property def func_name(self) -> str: @@ -38,40 +209,34 @@ class TaskPool: @property def size(self) -> int: - return len(self._tasks) + return self.num_running - def __str__(self) -> str: - return f'{self.__class__.__name__}-{self._name or self._idx}' + def _start_one(self) -> int: + return self._start_task(self._func(*self._args, **self._kwargs), + final_callback=self._final_callback, cancel_callback=self._cancel_callback) - def __repr__(self) -> str: - return f'<{self} func={self.func_name}>' + def start(self, num: int = 1) -> List[int]: + return [self._start_one() for _ in range(num)] - def _task_name(self, i: int) -> str: - return f'{self.func_name}_pool_task_{i}' - - def _start_one(self) -> None: - self._tasks.append(start_task(self._func(*self._args, **self._kwargs), self._task_name(self.size), - final_callback=self._final_callback, cancel_callback=self._cancel_callback)) - - def start(self, num: int = 1) -> None: - for _ in range(num): - self._start_one() - - def stop(self, num: int = 1) -> int: - for i in range(num): - try: - task = self._tasks.pop() - except IndexError: - num = i + def stop(self, num: int = 1) -> List[int]: + num = min(num, self.size) + ids = [] + for i, task_id in enumerate(reversed(self._running)): + if i >= num: break - task.cancel() - self._cancelled.append(task) - return num + ids.append(task_id) + self.cancel(*ids) + return ids - def stop_all(self) -> int: + def stop_all(self) -> List[int]: return self.stop(self.size) - async def close(self, return_exceptions: bool = False): - results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions) - self._tasks = self._cancelled = [] - return results + +async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None: + if kwargs is None: + kwargs = {} + if callable(func): + if iscoroutinefunction(func): + await func(*args, **kwargs) + else: + func(*args, **kwargs) diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index 7f57ac3..421e2b3 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Tuple, Union, Optional from . import constants -from .pool import TaskPool +from .pool import SimpleTaskPool from .client import ControlClient, UnixControlClient @@ -29,7 +29,7 @@ def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]] return cmd[0], None -class ControlServer(ABC): +class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool client_class = ControlClient @abstractmethod @@ -40,8 +40,8 @@ class ControlServer(ABC): def final_callback(self) -> None: raise NotImplementedError - def __init__(self, pool: TaskPool, **server_kwargs) -> None: - self._pool: TaskPool = pool + def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: + self._pool: SimpleTaskPool = pool self._server_kwargs = server_kwargs self._server: Optional[AbstractServer] = None @@ -49,26 +49,22 @@ class ControlServer(ABC): if num is None: num = 1 log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num)) - self._pool.start(num) - size = self._pool.size - writer.write(f"{num} new {tasks_str(num)} started! {size} {tasks_str(size)} active now.".encode()) + writer.write(str(self._pool.start(num)).encode()) def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: if num is None: num = 1 log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num)) - num = self._pool.stop(num) # the requested number may be greater than the total number of running tasks - size = self._pool.size - writer.write(f"{num} {tasks_str(num)} stopped! {size} {tasks_str(size)} left.".encode()) + # the requested number may be greater than the total number of running tasks + writer.write(str(self._pool.stop(num)).encode()) def _stop_all_tasks(self, writer: StreamWriter) -> None: log.debug("%s requests stopping all tasks", self.client_class.__name__) - num = self._pool.stop_all() - writer.write(f"Remaining {num} {tasks_str(num)} stopped!".encode()) + writer.write(str(self._pool.stop_all()).encode()) def _pool_size(self, writer: StreamWriter) -> None: log.debug("%s requests pool size", self.client_class.__name__) - writer.write(f'{self._pool.size}'.encode()) + writer.write(str(self._pool.size).encode()) def _pool_func(self, writer: StreamWriter) -> None: log.debug("%s requests pool function", self.client_class.__name__) @@ -98,7 +94,7 @@ class ControlServer(ABC): async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: log.debug("%s connected", self.client_class.__name__) - writer.write(f"{self.__class__.__name__} for {self._pool}".encode()) + writer.write(str(self._pool).encode()) await writer.drain() await self._listen(reader, writer) @@ -120,7 +116,7 @@ class ControlServer(ABC): class UnixControlServer(ControlServer): client_class = UnixControlClient - def __init__(self, pool: TaskPool, **server_kwargs) -> None: + def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: self._socket_path = Path(server_kwargs.pop('path')) super().__init__(pool, **server_kwargs) diff --git a/src/asyncio_taskpool/task.py b/src/asyncio_taskpool/task.py deleted file mode 100644 index d006212..0000000 --- a/src/asyncio_taskpool/task.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging -from asyncio.exceptions import CancelledError -from asyncio.tasks import Task, create_task -from typing import Awaitable, Any - -from .types import FinalCallbackT, CancelCallbackT - - -log = logging.getLogger(__name__) - - -async def wrap(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None, - cancel_callback: CancelCallbackT = None) -> Any: - log.info("Started %s", task_name) - try: - return await awaitable - except CancelledError: - log.info("Cancelling %s ...", task_name) - if callable(cancel_callback): - cancel_callback() - log.info("Cancelled %s", task_name) - finally: - if callable(final_callback): - final_callback() - log.info("Exiting %s", task_name) - - -def start_task(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None, - cancel_callback: CancelCallbackT = None) -> Task: - return create_task(wrap(awaitable, task_name, final_callback, cancel_callback), name=task_name) diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index 94a0589..2caf9a6 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -1,7 +1,9 @@ from asyncio.streams import StreamReader, StreamWriter -from typing import Tuple, Callable, Awaitable, Union, Any +from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union +ArgsT = Iterable[Any] +KwArgsT = Mapping[str, Any] CoroutineFunc = Callable[[...], Awaitable[Any]] FinalCallbackT = Callable CancelCallbackT = Callable diff --git a/usage/USAGE.md b/usage/USAGE.md index e9b9aab..8383329 100644 --- a/usage/USAGE.md +++ b/usage/USAGE.md @@ -1,8 +1,8 @@ # Using `asyncio-taskpool` -## Simple example +## Minimal example for `SimpleTaskPool` -The minimum required setup is a "worker" coroutine function that can do something asynchronously, a main coroutine function that sets up the `TaskPool` and starts/stops the tasks as desired, eventually awaiting them all. +The minimum required setup is a "worker" coroutine function that can do something asynchronously, a main coroutine function that sets up the `SimpleTaskPool` and starts/stops the tasks as desired, eventually awaiting them all. The following demo code enables full log output first for additional clarity. It is complete and should work as is. @@ -11,7 +11,7 @@ The following demo code enables full log output first for additional clarity. It import logging import asyncio -from asyncio_taskpool.pool import TaskPool +from asyncio_taskpool.pool import SimpleTaskPool logging.getLogger().setLevel(logging.NOTSET) @@ -32,13 +32,14 @@ async def work(n: int) -> None: async def main() -> None: - pool = TaskPool(work, (5,)) # initializes the pool; no work is being done yet + pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet pool.start(3) # launches work tasks 0, 1, and 2 await asyncio.sleep(1.5) # lets the tasks work for a bit pool.start() # launches work task 3 await asyncio.sleep(1.5) # lets the tasks work for a bit pool.stop(2) # cancels tasks 3 and 2 - await pool.close() # awaits all tasks, then flushes the pool + pool.close() # required for the last line + await pool.gather() # awaits all tasks, then flushes the pool if __name__ == '__main__': @@ -46,31 +47,32 @@ if __name__ == '__main__': ``` ### Output -Additional comments indicated with `<--` ``` -Started work_pool_task_0 -Started work_pool_task_1 -Started work_pool_task_2 +SimpleTaskPool-0 initialized +Started SimpleTaskPool-0_Task-0 +Started SimpleTaskPool-0_Task-1 +Started SimpleTaskPool-0_Task-2 did 0 did 0 did 0 -Started work_pool_task_3 +Started SimpleTaskPool-0_Task-3 did 1 did 1 did 1 -did 0 <-- notice that the newly created task begins counting at 0 +did 0 +SimpleTaskPool-0 is closed! +Cancelling SimpleTaskPool-0_Task-3 ... +Cancelled SimpleTaskPool-0_Task-3 +Ended SimpleTaskPool-0_Task-3 +Cancelling SimpleTaskPool-0_Task-2 ... +Cancelled SimpleTaskPool-0_Task-2 +Ended SimpleTaskPool-0_Task-2 +did 2 did 2 -did 2 <-- two taks were stopped; only tasks 0 and 1 continue "working" -Cancelling work_pool_task_2 ... -Cancelled work_pool_task_2 -Exiting work_pool_task_2 -Cancelling work_pool_task_3 ... -Cancelled work_pool_task_3 -Exiting work_pool_task_3 did 3 did 3 -Exiting work_pool_task_0 -Exiting work_pool_task_1 +Ended SimpleTaskPool-0_Task-0 +Ended SimpleTaskPool-0_Task-1 did 4 did 4 ``` diff --git a/usage/example_server.py b/usage/example_server.py index 08dfb26..5acc831 100644 --- a/usage/example_server.py +++ b/usage/example_server.py @@ -1,7 +1,7 @@ import asyncio import logging -from asyncio_taskpool import TaskPool, UnixControlServer +from asyncio_taskpool import SimpleTaskPool, UnixControlServer from asyncio_taskpool.constants import PACKAGE_NAME @@ -43,7 +43,7 @@ async def main() -> None: # We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit. for item in range(100): q.put_nowait(item) - pool = TaskPool(worker, (q,)) # initializes the pool + pool = SimpleTaskPool(worker, (q,)) # initializes the pool pool.start(3) # launches three worker tasks control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() # We block until `.task_done()` has been called once by our workers for every item placed into the queue. @@ -53,10 +53,11 @@ async def main() -> None: # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, # we can now safely cancel their tasks. pool.stop_all() + pool.close() # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. # We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # we just silently collect their exceptions along with their return values. - await pool.close(return_exceptions=True) + await pool.gather(return_exceptions=True) await control_server_task