generated from daniil-berg/boilerplate-py
first working draft
This commit is contained in:
@ -0,0 +1 @@
|
||||
from .pool import TaskPool
|
||||
|
53
src/asyncio_taskpool/pool.py
Normal file
53
src/asyncio_taskpool/pool.py
Normal file
@ -0,0 +1,53 @@
|
||||
import logging
|
||||
from asyncio import gather
|
||||
from asyncio.tasks import Task
|
||||
from typing import Mapping, List, Iterable, Any
|
||||
|
||||
from .types import CoroutineFunc, FinalCallbackT, CancelCallbackT
|
||||
from .task import start_task
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskPool:
|
||||
def __init__(self, func: CoroutineFunc, args: Iterable[Any] = (), kwargs: Mapping[str, Any] = None,
|
||||
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
|
||||
self._func: CoroutineFunc = func
|
||||
self._args: Iterable[Any] = args
|
||||
self._kwargs: Mapping[str, Any] = kwargs if kwargs is not None else {}
|
||||
self._final_callback: FinalCallbackT = final_callback
|
||||
self._cancel_callback: CancelCallbackT = cancel_callback
|
||||
self._tasks: List[Task] = []
|
||||
|
||||
@property
|
||||
def func_name(self) -> str:
|
||||
return self._func.__name__
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return len(self._tasks)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} func={self.func_name} size={self.size}>'
|
||||
|
||||
def _task_name(self, i: int) -> str:
|
||||
return f'{self.func_name}_pool_task_{i}'
|
||||
|
||||
def _start_one(self) -> None:
|
||||
self._tasks.append(start_task(self._func(*self._args, **self._kwargs), self._task_name(self.size),
|
||||
final_callback=self._final_callback, cancel_callback=self._cancel_callback))
|
||||
|
||||
def start(self, num: int = 1) -> None:
|
||||
for _ in range(num):
|
||||
self._start_one()
|
||||
|
||||
def stop(self, num: int = 1) -> int:
|
||||
if num < 1:
|
||||
return 0
|
||||
return sum(task.cancel() for task in reversed(self._tasks[-num:]))
|
||||
|
||||
async def gather(self, return_exceptions: bool = False):
|
||||
results = await gather(*self._tasks, return_exceptions=return_exceptions)
|
||||
self._tasks = []
|
||||
return results
|
30
src/asyncio_taskpool/task.py
Normal file
30
src/asyncio_taskpool/task.py
Normal file
@ -0,0 +1,30 @@
|
||||
import logging
|
||||
from asyncio.exceptions import CancelledError
|
||||
from asyncio.tasks import Task, create_task
|
||||
from typing import Awaitable, Any
|
||||
|
||||
from .types import FinalCallbackT, CancelCallbackT
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def wrap(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None,
|
||||
cancel_callback: CancelCallbackT = None) -> Any:
|
||||
log.info("Started %s", task_name)
|
||||
try:
|
||||
return await awaitable
|
||||
except CancelledError:
|
||||
log.info("Cancelling %s ...", task_name)
|
||||
if callable(cancel_callback):
|
||||
cancel_callback()
|
||||
log.info("Cancelled %s", task_name)
|
||||
finally:
|
||||
if callable(final_callback):
|
||||
final_callback()
|
||||
log.info("Exiting %s", task_name)
|
||||
|
||||
|
||||
def start_task(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None,
|
||||
cancel_callback: CancelCallbackT = None) -> Task:
|
||||
return create_task(wrap(awaitable, task_name, final_callback, cancel_callback), name=task_name)
|
6
src/asyncio_taskpool/types.py
Normal file
6
src/asyncio_taskpool/types.py
Normal file
@ -0,0 +1,6 @@
|
||||
from typing import Callable, Awaitable, Any
|
||||
|
||||
|
||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
||||
FinalCallbackT = Callable
|
||||
CancelCallbackT = Callable
|
Reference in New Issue
Block a user