huge rework; two different task pool classes now

This commit is contained in:
Daniil Fajnberg 2022-02-05 18:02:32 +01:00
parent f45fef6497
commit 3eae7d803f
8 changed files with 277 additions and 115 deletions

View File

@ -1,2 +1,2 @@
from .pool import TaskPool from .pool import TaskPool, SimpleTaskPool
from .server import UnixControlServer from .server import UnixControlServer

View File

@ -0,0 +1,26 @@
class PoolException(Exception):
pass
class PoolIsClosed(PoolException):
pass
class TaskEnded(PoolException):
pass
class AlreadyCancelled(TaskEnded):
pass
class AlreadyFinished(TaskEnded):
pass
class InvalidTaskID(PoolException):
pass
class PoolStillOpen(PoolException):
pass

View File

@ -1,36 +1,207 @@
import logging import logging
from asyncio import gather from asyncio import gather
from asyncio.tasks import Task from asyncio.coroutines import iscoroutinefunction
from typing import Mapping, List, Iterable, Any from asyncio.exceptions import CancelledError
from asyncio.locks import Event
from asyncio.tasks import Task, create_task
from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from .types import CoroutineFunc, FinalCallbackT, CancelCallbackT from . import exceptions
from .task import start_task from .types import ArgsT, KwArgsT, CoroutineFunc, FinalCallbackT, CancelCallbackT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TaskPool: class BaseTaskPool:
_pools: List['TaskPool'] = [] _pools: List['BaseTaskPool'] = []
@classmethod @classmethod
def _add_pool(cls, pool: 'TaskPool') -> int: def _add_pool(cls, pool: 'BaseTaskPool') -> int:
cls._pools.append(pool) cls._pools.append(pool)
return len(cls._pools) - 1 return len(cls._pools) - 1
def __init__(self, func: CoroutineFunc, args: Iterable[Any] = (), kwargs: Mapping[str, Any] = None, def __init__(self, max_size: int = inf, name: str = None) -> None:
self._max_size: int = max_size # TODO: Make use of a synchronization primitive for this to work
self._open: bool = True
self._counter: int = 0
self._running: Dict[int, Task] = {}
self._cancelled: Dict[int, Task] = {}
self._ended: Dict[int, Task] = {}
self._all_tasks_known: Event = Event()
self._all_tasks_known.set()
self._idx: int = self._add_pool(self)
self._name: str = name
log.debug("%s initialized", str(self))
def __str__(self) -> str:
return f'{self.__class__.__name__}-{self._name or self._idx}'
@property
def num_running(self) -> int:
return len(self._running)
@property
def num_cancelled(self) -> int:
return len(self._cancelled)
@property
def num_ended(self) -> int:
return len(self._ended)
@property
def num_finished(self) -> int:
return self.num_ended - self.num_cancelled
def _task_name(self, task_id: int) -> str:
return f'{self}_Task-{task_id}'
async def _cancel_task(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
log.debug("Cancelling %s ...", self._task_name(task_id))
task = self._running.pop(task_id)
assert task is not None
self._cancelled[task_id] = task
await _execute_function(custom_callback, args=(task_id, ))
log.debug("Cancelled %s", self._task_name(task_id))
async def _end_task(self, task_id: int, custom_callback: FinalCallbackT = None) -> None:
task = self._running.pop(task_id, None)
if task is None:
task = self._cancelled[task_id]
self._ended[task_id] = task
await _execute_function(custom_callback, args=(task_id, ))
log.info("Ended %s", self._task_name(task_id))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any:
log.info("Started %s", self._task_name(task_id))
try:
return await awaitable
except CancelledError:
await self._cancel_task(task_id, custom_callback=cancel_callback)
finally:
await self._end_task(task_id, custom_callback=final_callback)
def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int:
if not (self._open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
task_id = self._counter
self._counter += 1
self._running[task_id] = create_task(
self._task_wrapper(awaitable, task_id, final_callback, cancel_callback),
name=self._task_name(task_id)
)
return task_id
def _cancel_one(self, task_id: int, msg: str = None) -> None:
try:
task = self._running[task_id]
except KeyError:
if self._cancelled.get(task_id):
raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
if self._ended.get(task_id):
raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
task.cancel(msg=msg)
def cancel(self, *task_ids: int, msg: str = None) -> None:
for task_id in task_ids:
self._cancel_one(task_id, msg=msg)
def cancel_all(self, msg: str = None) -> None:
for task in self._running.values():
task.cancel(msg=msg)
def close(self) -> None:
self._open = False
log.info("%s is closed!", str(self))
async def gather(self, return_exceptions: bool = False):
if self._open:
raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
await self._all_tasks_known.wait()
results = await gather(*self._running.values(), *self._ended.values(), return_exceptions=return_exceptions)
self._running = self._cancelled = self._ended = {}
return results
class TaskPool(BaseTaskPool):
def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
if kwargs is None:
kwargs = {}
return self._start_task(func(*args, **kwargs), final_callback=final_callback, cancel_callback=cancel_callback)
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
return tuple(self._apply_one(func, args, kwargs, final_callback, cancel_callback) for _ in range(num))
@staticmethod
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
try:
arg = next(args_iter)
except StopIteration:
return
if arg_stars == 0:
return func(arg)
if arg_stars == 1:
return func(*arg)
if arg_stars == 2:
return func(**arg)
raise ValueError
def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
if self._all_tasks_known.is_set():
self._all_tasks_known.clear()
args_iter = iter(args_iter)
def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None:
self._all_tasks_known.set()
return True
self._start_task(cor, ignore_closed=True, final_callback=_start_next, cancel_callback=cancel_callback)
return False
async def _start_next(task_id: int) -> None:
await _execute_function(final_callback, args=(task_id, ))
_start_next_coroutine()
for _ in range(num_tasks):
reached_end = _start_next_coroutine()
if reached_end:
break
def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback)
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback)
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback)
class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None, final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None: name: str = None) -> None:
self._func: CoroutineFunc = func self._func: CoroutineFunc = func
self._args: Iterable[Any] = args self._args: ArgsT = args
self._kwargs: Mapping[str, Any] = kwargs if kwargs is not None else {} self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
self._final_callback: FinalCallbackT = final_callback self._final_callback: FinalCallbackT = final_callback
self._cancel_callback: CancelCallbackT = cancel_callback self._cancel_callback: CancelCallbackT = cancel_callback
self._tasks: List[Task] = [] super().__init__(name=name)
self._cancelled: List[Task] = []
self._idx: int = self._add_pool(self)
self._name: str = name
log.debug("%s initialized", repr(self))
@property @property
def func_name(self) -> str: def func_name(self) -> str:
@ -38,40 +209,34 @@ class TaskPool:
@property @property
def size(self) -> int: def size(self) -> int:
return len(self._tasks) return self.num_running
def __str__(self) -> str: def _start_one(self) -> int:
return f'{self.__class__.__name__}-{self._name or self._idx}' return self._start_task(self._func(*self._args, **self._kwargs),
final_callback=self._final_callback, cancel_callback=self._cancel_callback)
def __repr__(self) -> str: def start(self, num: int = 1) -> List[int]:
return f'<{self} func={self.func_name}>' return [self._start_one() for _ in range(num)]
def _task_name(self, i: int) -> str: def stop(self, num: int = 1) -> List[int]:
return f'{self.func_name}_pool_task_{i}' num = min(num, self.size)
ids = []
def _start_one(self) -> None: for i, task_id in enumerate(reversed(self._running)):
self._tasks.append(start_task(self._func(*self._args, **self._kwargs), self._task_name(self.size), if i >= num:
final_callback=self._final_callback, cancel_callback=self._cancel_callback))
def start(self, num: int = 1) -> None:
for _ in range(num):
self._start_one()
def stop(self, num: int = 1) -> int:
for i in range(num):
try:
task = self._tasks.pop()
except IndexError:
num = i
break break
task.cancel() ids.append(task_id)
self._cancelled.append(task) self.cancel(*ids)
return num return ids
def stop_all(self) -> int: def stop_all(self) -> List[int]:
return self.stop(self.size) return self.stop(self.size)
async def close(self, return_exceptions: bool = False):
results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions) async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None:
self._tasks = self._cancelled = [] if kwargs is None:
return results kwargs = {}
if callable(func):
if iscoroutinefunction(func):
await func(*args, **kwargs)
else:
func(*args, **kwargs)

View File

@ -8,7 +8,7 @@ from pathlib import Path
from typing import Tuple, Union, Optional from typing import Tuple, Union, Optional
from . import constants from . import constants
from .pool import TaskPool from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient from .client import ControlClient, UnixControlClient
@ -29,7 +29,7 @@ def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]
return cmd[0], None return cmd[0], None
class ControlServer(ABC): class ControlServer(ABC): # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
client_class = ControlClient client_class = ControlClient
@abstractmethod @abstractmethod
@ -40,8 +40,8 @@ class ControlServer(ABC):
def final_callback(self) -> None: def final_callback(self) -> None:
raise NotImplementedError raise NotImplementedError
def __init__(self, pool: TaskPool, **server_kwargs) -> None: def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._pool: TaskPool = pool self._pool: SimpleTaskPool = pool
self._server_kwargs = server_kwargs self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None self._server: Optional[AbstractServer] = None
@ -49,26 +49,22 @@ class ControlServer(ABC):
if num is None: if num is None:
num = 1 num = 1
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num)) log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
self._pool.start(num) writer.write(str(self._pool.start(num)).encode())
size = self._pool.size
writer.write(f"{num} new {tasks_str(num)} started! {size} {tasks_str(size)} active now.".encode())
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None: def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None: if num is None:
num = 1 num = 1
log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num)) log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
num = self._pool.stop(num) # the requested number may be greater than the total number of running tasks # the requested number may be greater than the total number of running tasks
size = self._pool.size writer.write(str(self._pool.stop(num)).encode())
writer.write(f"{num} {tasks_str(num)} stopped! {size} {tasks_str(size)} left.".encode())
def _stop_all_tasks(self, writer: StreamWriter) -> None: def _stop_all_tasks(self, writer: StreamWriter) -> None:
log.debug("%s requests stopping all tasks", self.client_class.__name__) log.debug("%s requests stopping all tasks", self.client_class.__name__)
num = self._pool.stop_all() writer.write(str(self._pool.stop_all()).encode())
writer.write(f"Remaining {num} {tasks_str(num)} stopped!".encode())
def _pool_size(self, writer: StreamWriter) -> None: def _pool_size(self, writer: StreamWriter) -> None:
log.debug("%s requests pool size", self.client_class.__name__) log.debug("%s requests pool size", self.client_class.__name__)
writer.write(f'{self._pool.size}'.encode()) writer.write(str(self._pool.size).encode())
def _pool_func(self, writer: StreamWriter) -> None: def _pool_func(self, writer: StreamWriter) -> None:
log.debug("%s requests pool function", self.client_class.__name__) log.debug("%s requests pool function", self.client_class.__name__)
@ -98,7 +94,7 @@ class ControlServer(ABC):
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
log.debug("%s connected", self.client_class.__name__) log.debug("%s connected", self.client_class.__name__)
writer.write(f"{self.__class__.__name__} for {self._pool}".encode()) writer.write(str(self._pool).encode())
await writer.drain() await writer.drain()
await self._listen(reader, writer) await self._listen(reader, writer)
@ -120,7 +116,7 @@ class ControlServer(ABC):
class UnixControlServer(ControlServer): class UnixControlServer(ControlServer):
client_class = UnixControlClient client_class = UnixControlClient
def __init__(self, pool: TaskPool, **server_kwargs) -> None: def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
self._socket_path = Path(server_kwargs.pop('path')) self._socket_path = Path(server_kwargs.pop('path'))
super().__init__(pool, **server_kwargs) super().__init__(pool, **server_kwargs)

View File

@ -1,30 +0,0 @@
import logging
from asyncio.exceptions import CancelledError
from asyncio.tasks import Task, create_task
from typing import Awaitable, Any
from .types import FinalCallbackT, CancelCallbackT
log = logging.getLogger(__name__)
async def wrap(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any:
log.info("Started %s", task_name)
try:
return await awaitable
except CancelledError:
log.info("Cancelling %s ...", task_name)
if callable(cancel_callback):
cancel_callback()
log.info("Cancelled %s", task_name)
finally:
if callable(final_callback):
final_callback()
log.info("Exiting %s", task_name)
def start_task(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Task:
return create_task(wrap(awaitable, task_name, final_callback, cancel_callback), name=task_name)

View File

@ -1,7 +1,9 @@
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Tuple, Callable, Awaitable, Union, Any from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Awaitable[Any]]
FinalCallbackT = Callable FinalCallbackT = Callable
CancelCallbackT = Callable CancelCallbackT = Callable

View File

@ -1,8 +1,8 @@
# Using `asyncio-taskpool` # Using `asyncio-taskpool`
## Simple example ## Minimal example for `SimpleTaskPool`
The minimum required setup is a "worker" coroutine function that can do something asynchronously, a main coroutine function that sets up the `TaskPool` and starts/stops the tasks as desired, eventually awaiting them all. The minimum required setup is a "worker" coroutine function that can do something asynchronously, a main coroutine function that sets up the `SimpleTaskPool` and starts/stops the tasks as desired, eventually awaiting them all.
The following demo code enables full log output first for additional clarity. It is complete and should work as is. The following demo code enables full log output first for additional clarity. It is complete and should work as is.
@ -11,7 +11,7 @@ The following demo code enables full log output first for additional clarity. It
import logging import logging
import asyncio import asyncio
from asyncio_taskpool.pool import TaskPool from asyncio_taskpool.pool import SimpleTaskPool
logging.getLogger().setLevel(logging.NOTSET) logging.getLogger().setLevel(logging.NOTSET)
@ -32,13 +32,14 @@ async def work(n: int) -> None:
async def main() -> None: async def main() -> None:
pool = TaskPool(work, (5,)) # initializes the pool; no work is being done yet pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
pool.start(3) # launches work tasks 0, 1, and 2 pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.start() # launches work task 3 pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2 pool.stop(2) # cancels tasks 3 and 2
await pool.close() # awaits all tasks, then flushes the pool pool.close() # required for the last line
await pool.gather() # awaits all tasks, then flushes the pool
if __name__ == '__main__': if __name__ == '__main__':
@ -46,31 +47,32 @@ if __name__ == '__main__':
``` ```
### Output ### Output
Additional comments indicated with `<--`
``` ```
Started work_pool_task_0 SimpleTaskPool-0 initialized
Started work_pool_task_1 Started SimpleTaskPool-0_Task-0
Started work_pool_task_2 Started SimpleTaskPool-0_Task-1
Started SimpleTaskPool-0_Task-2
did 0 did 0
did 0 did 0
did 0 did 0
Started work_pool_task_3 Started SimpleTaskPool-0_Task-3
did 1 did 1
did 1 did 1
did 1 did 1
did 0 <-- notice that the newly created task begins counting at 0 did 0
SimpleTaskPool-0 is closed!
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
Cancelling SimpleTaskPool-0_Task-2 ...
Cancelled SimpleTaskPool-0_Task-2
Ended SimpleTaskPool-0_Task-2
did 2
did 2 did 2
did 2 <-- two taks were stopped; only tasks 0 and 1 continue "working"
Cancelling work_pool_task_2 ...
Cancelled work_pool_task_2
Exiting work_pool_task_2
Cancelling work_pool_task_3 ...
Cancelled work_pool_task_3
Exiting work_pool_task_3
did 3 did 3
did 3 did 3
Exiting work_pool_task_0 Ended SimpleTaskPool-0_Task-0
Exiting work_pool_task_1 Ended SimpleTaskPool-0_Task-1
did 4 did 4
did 4 did 4
``` ```

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from asyncio_taskpool import TaskPool, UnixControlServer from asyncio_taskpool import SimpleTaskPool, UnixControlServer
from asyncio_taskpool.constants import PACKAGE_NAME from asyncio_taskpool.constants import PACKAGE_NAME
@ -43,7 +43,7 @@ async def main() -> None:
# We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit. # We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit.
for item in range(100): for item in range(100):
q.put_nowait(item) q.put_nowait(item)
pool = TaskPool(worker, (q,)) # initializes the pool pool = SimpleTaskPool(worker, (q,)) # initializes the pool
pool.start(3) # launches three worker tasks pool.start(3) # launches three worker tasks
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue. # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
@ -53,10 +53,11 @@ async def main() -> None:
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks. # we can now safely cancel their tasks.
pool.stop_all() pool.stop_all()
pool.close()
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
# we just silently collect their exceptions along with their return values. # we just silently collect their exceptions along with their return values.
await pool.close(return_exceptions=True) await pool.gather(return_exceptions=True)
await control_server_task await control_server_task