asyncio-taskpool/src/asyncio_taskpool/pool.py

64 lines
2.1 KiB
Python
Raw Normal View History

2022-02-04 11:07:02 +01:00
import logging
from asyncio import gather
from asyncio.tasks import Task
from typing import Mapping, List, Iterable, Any
from .types import CoroutineFunc, FinalCallbackT, CancelCallbackT
from .task import start_task
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
class TaskPool:
    """Pool of tasks that all run the same coroutine function.

    Every task started through the pool is built from
    ``func(*args, **kwargs)`` using the arguments given at construction time.
    Tasks that are still tracked as running and tasks that were cancelled via
    :meth:`stop` are kept in two separate lists until :meth:`gather` has
    awaited them.
    """

    def __init__(self, func: CoroutineFunc, args: Iterable[Any] = (), kwargs: Mapping[str, Any] = None,
                 final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """Store the coroutine function, its arguments and the callbacks.

        :param func: coroutine function called once for every started task.
        :param args: positional arguments unpacked into each ``func`` call.
        :param kwargs: keyword arguments unpacked into each ``func`` call;
            ``None`` is replaced by a fresh empty dict.
        :param final_callback: forwarded unchanged to ``start_task``.
        :param cancel_callback: forwarded unchanged to ``start_task``.
        """
        self._func: CoroutineFunc = func
        self._args: Iterable[Any] = args
        self._kwargs: Mapping[str, Any] = kwargs if kwargs is not None else {}
        self._final_callback: FinalCallbackT = final_callback
        self._cancel_callback: CancelCallbackT = cancel_callback
        # Tasks started by the pool that have not been stopped.
        self._tasks: List[Task] = []
        # Tasks cancelled through `stop` that `gather` has not awaited yet.
        self._cancelled: List[Task] = []

    @property
    def func_name(self) -> str:
        """Return the ``__name__`` of the pool's coroutine function."""
        return self._func.__name__

    @property
    def size(self) -> int:
        """Return the number of tasks currently tracked as running."""
        return len(self._tasks)

    def __repr__(self) -> str:
        """Return a short description with class name, function name and size."""
        return f'<{self.__class__.__name__} func={self.func_name} size={self.size}>'

    def _task_name(self, i: int) -> str:
        """Return the name used for the task with index ``i``."""
        return f'{self.func_name}_pool_task_{i}'

    def _start_one(self) -> None:
        """Start a single task via ``start_task`` and track it in the pool."""
        # The current size is the index embedded in the new task's name.
        self._tasks.append(start_task(self._func(*self._args, **self._kwargs), self._task_name(self.size),
                                      final_callback=self._final_callback, cancel_callback=self._cancel_callback))

    def start(self, num: int = 1) -> None:
        """Start ``num`` new tasks (none if ``num`` is zero or negative)."""
        for _ in range(num):
            self._start_one()

    def stop(self, num: int = 1) -> int:
        """Cancel up to ``num`` tasks, most recently started first.

        Cancelled tasks are moved to the cancelled list so that
        :meth:`gather` can still await them.

        :param num: maximum number of tasks to cancel.
        :return: the number of tasks actually cancelled; this is smaller than
            ``num`` when the pool runs empty and ``0`` when ``num`` is not
            positive.
        """
        stopped = 0
        for _ in range(num):
            if not self._tasks:
                # Fewer tasks than requested: report what was really cancelled.
                break
            task = self._tasks.pop()
            task.cancel()
            self._cancelled.append(task)
            stopped += 1
        return stopped

    def stop_all(self) -> int:
        """Cancel every tracked task and return how many were cancelled."""
        return self.stop(self.size)

    async def gather(self, return_exceptions: bool = False):
        """Await all running and cancelled tasks, then forget them.

        :param return_exceptions: passed through to :func:`asyncio.gather`.
        :return: the list produced by :func:`asyncio.gather`, running tasks
            first, followed by the cancelled ones.
        """
        results = await gather(*self._tasks, *self._cancelled, return_exceptions=return_exceptions)
        # Use two distinct lists: a chained assignment would bind both
        # attributes to the same object, so tasks started afterwards would
        # also show up as cancelled and `stop` could never shrink the pool.
        self._tasks = []
        self._cancelled = []
        return results