Compare commits

...

14 Commits

14 changed files with 819 additions and 123 deletions

View File

@ -4,66 +4,7 @@ Dynamically manage pools of asyncio tasks
## Usage
Demo:
```python
import logging
import asyncio
from asyncio_taskpool.pool import TaskPool
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
async def work(n):
for i in range(n):
await asyncio.sleep(1)
print("did", i)
async def main():
pool = TaskPool(work, (5,)) # initializes the pool
pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5)
pool.start() # launches work task 3
await asyncio.sleep(1.5)
pool.stop(2) # cancels tasks 3 and 2
await pool.gather() # awaits all tasks, then flushes the pool
if __name__ == '__main__':
asyncio.run(main())
```
Output:
```
Started work_pool_task_0
Started work_pool_task_1
Started work_pool_task_2
did 0
did 0
did 0
Started work_pool_task_3
did 1
did 1
did 1
did 0
did 2
did 2
Cancelling work_pool_task_2 ...
Cancelled work_pool_task_2
Exiting work_pool_task_2
Cancelling work_pool_task_3 ...
Cancelled work_pool_task_3
Exiting work_pool_task_3
did 3
did 3
Exiting work_pool_task_0
Exiting work_pool_task_1
did 4
did 4
```
See [USAGE.md](usage/USAGE.md)
## Installation

View File

@ -1 +1,2 @@
from .pool import TaskPool
from .pool import TaskPool, SimpleTaskPool
from .server import UnixControlServer

View File

@ -0,0 +1,46 @@
import sys
from argparse import ArgumentParser
from asyncio import run
from pathlib import Path
from typing import Dict, Any
from .client import ControlClient, UnixControlClient
from .constants import PACKAGE_NAME
from .pool import TaskPool
from .server import ControlServer
# Key under which the chosen sub-command (connection type) is stored in the parsed arguments;
# it is passed as `dest` to `add_subparsers` in `parse_cli`.
CONN_TYPE = 'conn_type'
# Names of the connection types. Only `UNIX` has a sub-parser registered in `parse_cli`.
UNIX, TCP = 'unix', 'tcp'
# Name of the positional argument that holds the path to the unix socket.
SOCKET_PATH = 'path'
def parse_cli() -> Dict[str, Any]:
    """
    Parses the command line arguments of the control client.

    One sub-command per connection type is offered; currently only the unix socket variant is registered,
    which takes the socket path as its single positional argument.

    Returns:
        Dictionary mapping every argument name (including the connection type) to its parsed value.
    """
    cli_parser = ArgumentParser(
        prog=PACKAGE_NAME,
        description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}",
    )
    conn_type_parsers = cli_parser.add_subparsers(title="Connection types", dest=CONN_TYPE)
    socket_help = (
        f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
    )
    conn_type_parsers.add_parser(UNIX, help="Connect via unix socket").add_argument(
        SOCKET_PATH,
        type=Path,
        help=socket_help,
    )
    namespace = cli_parser.parse_args()
    return vars(namespace)
async def main():
    """
    Entry point of the CLI client.

    Reads the command line arguments, creates the matching control client and runs its interactive session.
    Exits the process with status 2, if no known connection type was chosen.
    """
    cli_kwargs = parse_cli()
    conn_type = cli_kwargs[CONN_TYPE]
    if conn_type not in (UNIX, TCP):
        print("Invalid connection type", file=sys.stderr)
        sys.exit(2)
    if conn_type == UNIX:
        client = UnixControlClient(path=cli_kwargs[SOCKET_PATH])
    else:
        # TODO: Implement the TCP client class
        client = UnixControlClient(path=cli_kwargs[SOCKET_PATH])
    await client.start()


if __name__ == '__main__':
    run(main())

View File

@ -0,0 +1,63 @@
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from pathlib import Path
from asyncio_taskpool import constants
from asyncio_taskpool.types import ClientConnT
class ControlClient(ABC):
    """
    Abstract base class for an interactive command line client of a control server.

    Subclasses only have to implement `open_connection`; the prompt loop is provided by `start`.
    """

    @abstractmethod
    async def open_connection(self, **kwargs) -> ClientConnT:
        """
        Opens the connection to the control server.

        Must return a `(reader, writer)` pair. `start` treats a reader of `None` as a failed connection attempt.
        """
        raise NotImplementedError

    def __init__(self, **conn_kwargs) -> None:
        """Stores the keyword arguments that `start` later passes on to `open_connection`."""
        self._conn_kwargs = conn_kwargs
        self._connected: bool = False

    async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
        """
        Performs one prompt/response round trip with the server.

        Reads one line from the user, sends it to the server and prints the reply.
        End-of-file on stdin is treated like the exit command; a keyboard interrupt just aborts the current prompt.
        Sets `_connected` to `False` as soon as the connection is closed or found to be lost.
        """
        try:
            msg = input("> ").strip().lower()
        except EOFError:
            msg = constants.CLIENT_EXIT
        except KeyboardInterrupt:
            print()
            return
        if not msg:
            # An empty message would put no data on the wire, so the server would never answer
            # and the read further down would wait indefinitely. Just show the prompt again.
            return
        if msg == constants.CLIENT_EXIT:
            writer.close()
            self._connected = False
            return
        try:
            writer.write(msg.encode())
            await writer.drain()
        except ConnectionError as e:
            self._connected = False
            print(e, file=sys.stderr)
            return
        response = await reader.read(constants.MSG_BYTES)
        if not response:
            # An empty read means EOF, i.e. the server has closed the connection.
            self._connected = False
            return
        print(response.decode())

    async def start(self):
        """
        Connects to the server and runs the interactive prompt until the connection ends.

        The first message received after connecting is printed as the name of the connected pool.
        """
        reader, writer = await self.open_connection(**self._conn_kwargs)
        if reader is None:
            print("Failed to connect.", file=sys.stderr)
            return
        self._connected = True
        print("Connected to", (await reader.read(constants.MSG_BYTES)).decode())
        while self._connected:
            await self._interact(reader, writer)
        print("Disconnected from control server.")
class UnixControlClient(ControlClient):
    """Control client that connects to the server via a unix socket."""

    def __init__(self, **conn_kwargs) -> None:
        """
        Takes the mandatory `path` keyword argument as the socket path.

        All remaining keyword arguments are stored by the parent class and later passed to `open_connection`.
        """
        self._socket_path = Path(conn_kwargs.pop('path'))
        super().__init__(**conn_kwargs)

    async def open_connection(self, **kwargs) -> ClientConnT:
        """
        Opens a unix socket connection to the stored socket path.

        Returns:
            The `(reader, writer)` pair of the connection, or `(None, None)` if the connection could not be
            established (the reason is printed to stderr).
        """
        try:
            return await open_unix_connection(self._socket_path, **kwargs)
        except FileNotFoundError:
            print("No socket at", self._socket_path, file=sys.stderr)
        except ConnectionRefusedError:
            # The socket file exists, but nothing is listening on it (e.g. left behind by a crashed server).
            print("Connection refused at", self._socket_path, file=sys.stderr)
        return None, None

View File

@ -0,0 +1,8 @@
# Name of the package; used as the program name of the CLI client and as the logger name in the usage example.
PACKAGE_NAME = 'asyncio_taskpool'
# Maximum number of bytes read from a stream in a single `reader.read(...)` call by client and server.
MSG_BYTES = 1024
# Commands the control server compares incoming messages against:
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all'
CMD_SIZE = 'size'
CMD_FUNC = 'func'
# Input that makes the client close its connection; it is handled by the client and not sent to the server.
CLIENT_EXIT = 'exit'

View File

@ -0,0 +1,26 @@
class PoolException(Exception):
    """Base class of all exceptions raised by task pools."""


class PoolIsClosed(PoolException):
    """Raised when a new task is supposed to be started in a pool that has been closed."""


class TaskEnded(PoolException):
    """Base class for errors about a task that is no longer running."""


class AlreadyCancelled(TaskEnded):
    """Raised when a running task is requested by ID, but that task has already been cancelled."""


class AlreadyFinished(TaskEnded):
    """Raised when a running task is requested by ID, but that task has already ended."""


class InvalidTaskID(PoolException):
    """Raised when no task with the requested ID is known to the pool."""


class PoolStillOpen(PoolException):
    """Raised when the tasks of a pool are supposed to be gathered before the pool was closed."""

View File

@ -1,24 +1,272 @@
import logging
from asyncio import gather
from asyncio.tasks import Task
from typing import Mapping, List, Iterable, Any
from asyncio.coroutines import iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task
from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from .types import CoroutineFunc, FinalCallbackT, CancelCallbackT
from .task import start_task
from . import exceptions
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
log = logging.getLogger(__name__)
class TaskPool:
def __init__(self, func: CoroutineFunc, args: Iterable[Any] = (), kwargs: Mapping[str, Any] = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
class BaseTaskPool:
    """The base class for task pools. Not intended to be used directly."""
    _pools: List['BaseTaskPool'] = []

    @classmethod
    def _add_pool(cls, pool: 'BaseTaskPool') -> int:
        """Adds a `pool` (instance of any subclass) to the general list of pools and returns its index in the list."""
        cls._pools.append(pool)
        return len(cls._pools) - 1

    def __init__(self, pool_size: int = inf, name: str = None) -> None:
        """Initializes the necessary internal attributes and adds the new pool to the general pools list."""
        self._enough_room: Semaphore = Semaphore()
        self.pool_size = pool_size
        self._open: bool = True
        self._counter: int = 0
        self._running: Dict[int, Task] = {}
        self._cancelled: Dict[int, Task] = {}
        self._ended: Dict[int, Task] = {}
        self._num_cancelled: int = 0
        self._num_ended: int = 0
        self._idx: int = self._add_pool(self)
        self._name: str = name
        self._all_tasks_known_flag: Event = Event()
        self._all_tasks_known_flag.set()
        log.debug("%s initialized", str(self))

    def __str__(self) -> str:
        """Returns the class name followed by the pool's name or, if it has none, its index in the pools list."""
        return f'{self.__class__.__name__}-{self._name or self._idx}'

    @property
    def pool_size(self) -> int:
        """Returns the maximum number of tasks that are allowed to run in the pool at the same time."""
        return self._pool_size

    @pool_size.setter
    def pool_size(self, value: int) -> None:
        """
        Sets the pool size and overwrites the counter of the internal semaphore with the same value.

        NOTE(review): The counter is overwritten regardless of how many tasks are currently running.

        Raises:
            ValueError: if `value` is negative.
        """
        if value < 0:
            raise ValueError("Pool size can not be less than 0")
        self._enough_room._value = value
        self._pool_size = value

    @property
    def is_open(self) -> bool:
        """Returns `True` if the pool has not been closed yet."""
        return self._open

    @property
    def num_running(self) -> int:
        """
        Returns the number of tasks in the pool that are (at that moment) still running.
        At the moment a task's `end_callback` is fired, it is no longer considered to be running.
        """
        return len(self._running)

    @property
    def num_cancelled(self) -> int:
        """
        Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
        At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
        """
        return self._num_cancelled

    @property
    def num_ended(self) -> int:
        """
        Returns the number of tasks started through the pool that have stopped running (up until that moment).
        At the moment a task's `end_callback` is fired, it is considered ended.
        When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
        does it then actually end.
        """
        return self._num_ended

    @property
    def num_finished(self) -> int:
        """
        Returns the number of tasks in the pool that have actually finished running (without having been cancelled).
        """
        return self._num_ended - self._num_cancelled + len(self._cancelled)

    @property
    def is_full(self) -> bool:
        """
        Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
        When the pool is full, any call to start a new task within it will block.
        """
        return self._enough_room.locked()

    # TODO: Consider adding task group names
    def _task_name(self, task_id: int) -> str:
        """Returns a standardized name for a task with a specific `task_id`."""
        return f'{self}_Task-{task_id}'

    async def _task_cancellation(self, task_id: int, custom_callback: CancelCallbackT = None) -> None:
        """
        Does the bookkeeping for a task that was cancelled and then runs its cancel-callback.

        Moves the task from the running to the cancelled tasks, increments the cancellation counter and
        executes `custom_callback` (if it is callable) with the `task_id` as its only argument.
        """
        log.debug("Cancelling %s ...", self._task_name(task_id))
        self._cancelled[task_id] = self._running.pop(task_id)
        self._num_cancelled += 1
        log.debug("Cancelled %s", self._task_name(task_id))
        await _execute_function(custom_callback, args=(task_id, ))

    async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
        """
        Does the bookkeeping for a task that stopped running and then runs its end-callback.

        Moves the task from the running (or, if it was cancelled before, the cancelled) tasks to the ended tasks,
        increments the ended counter, frees up one spot in the pool and executes `custom_callback`
        (if it is callable) with the `task_id` as its only argument.
        """
        try:
            self._ended[task_id] = self._running.pop(task_id)
        except KeyError:
            # The task is not running anymore, because it was cancelled first.
            self._ended[task_id] = self._cancelled.pop(task_id)
        self._num_ended += 1
        self._enough_room.release()
        log.info("Ended %s", self._task_name(task_id))
        await _execute_function(custom_callback, args=(task_id, ))

    async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
                            cancel_callback: CancelCallbackT = None) -> Any:
        """
        Awaits the `awaitable` and makes sure the pool's bookkeeping is done afterwards.

        Returns whatever the `awaitable` returns. If it is cancelled, the cancellation is handled by
        `_task_cancellation` and not propagated, so the wrapper returns `None` in that case.
        `_task_ending` is executed in any case.
        """
        log.info("Started %s", self._task_name(task_id))
        try:
            return await awaitable
        except CancelledError:
            await self._task_cancellation(task_id, custom_callback=cancel_callback)
        finally:
            await self._task_ending(task_id, custom_callback=end_callback)

    async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
                          cancel_callback: CancelCallbackT = None) -> int:
        """
        Starts a new task for the `awaitable` in the pool, waiting for a free spot first if the pool is full.

        Returns:
            The ID assigned to the new task.

        Raises:
            PoolIsClosed: if the pool has been closed and `ignore_closed` is `False`.
        """
        if not (self.is_open or ignore_closed):
            raise exceptions.PoolIsClosed("Cannot start new tasks")
        await self._enough_room.acquire()
        try:
            task_id = self._counter
            self._counter += 1
            self._running[task_id] = create_task(
                self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
                name=self._task_name(task_id)
            )
        except Exception:
            # The spot acquired above is not going to be used, so it must be given back.
            self._enough_room.release()
            raise
        return task_id

    def _get_running_task(self, task_id: int) -> Task:
        """
        Returns the running task with the specified `task_id`.

        Raises:
            AlreadyCancelled: if the task has been cancelled (and has not ended yet).
            AlreadyFinished: if the task has ended.
            InvalidTaskID: if the pool knows no task with that ID.
        """
        try:
            return self._running[task_id]
        except KeyError:
            if self._cancelled.get(task_id):
                raise exceptions.AlreadyCancelled(f"{self._task_name(task_id)} has already been cancelled")
            if self._ended.get(task_id):
                raise exceptions.AlreadyFinished(f"{self._task_name(task_id)} has finished running")
            raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")

    def _cancel_task(self, task_id: int, msg: str = None) -> None:
        """Cancels the running task with the specified `task_id`, passing `msg` on to `Task.cancel`."""
        self._get_running_task(task_id).cancel(msg=msg)

    def cancel(self, *task_ids: int, msg: str = None) -> None:
        """
        Cancels the running tasks with the specified IDs.

        All IDs are looked up before the first task is cancelled, so that either all of them are cancelled
        or (if any lookup raises) none of them.
        """
        tasks = [self._get_running_task(task_id) for task_id in task_ids]
        for task in tasks:
            task.cancel(msg=msg)

    async def cancel_all(self, msg: str = None) -> None:
        """Waits until the internal "all tasks known" flag is set and then cancels every running task."""
        await self._all_tasks_known_flag.wait()
        for task in self._running.values():
            task.cancel(msg=msg)

    async def flush(self, return_exceptions: bool = False):
        """
        Awaits all ended and cancelled tasks, then forgets about them.

        Returns:
            The results as returned by `asyncio.gather` (to which `return_exceptions` is passed on).
        """
        results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
        # Two distinct dictionaries are needed here. If both attributes referred to the same object,
        # every task cancelled later on would show up as ended too (and vice versa).
        self._ended, self._cancelled = {}, {}
        return results

    def close(self) -> None:
        """Marks the pool as closed, after which `_start_task` refuses new tasks unless told to ignore that."""
        self._open = False
        log.info("%s is closed!", str(self))

    async def gather(self, return_exceptions: bool = False):
        """
        Awaits all ended, cancelled and running tasks, then forgets about them.

        Returns:
            The results as returned by `asyncio.gather` (to which `return_exceptions` is passed on).

        Raises:
            PoolStillOpen: if the pool has not been closed yet.
        """
        if self._open:
            raise exceptions.PoolStillOpen("Pool must be closed, before tasks can be gathered")
        await self._all_tasks_known_flag.wait()
        results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
                               return_exceptions=return_exceptions)
        # Each attribute gets its own dictionary; sharing one object between them would mix up
        # running, cancelled and ended tasks as soon as another task is started in the pool.
        self._ended, self._cancelled, self._running = {}, {}, {}
        return results
class TaskPool(BaseTaskPool):
    """Task pool in which tasks for arbitrary coroutine functions can be started."""

    async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
                         end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
        """Starts one task that runs `func(*args, **kwargs)` and returns the ID of that task."""
        if kwargs is None:
            kwargs = {}
        return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)

    async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
                    end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int, ...]:
        """
        Starts `num` tasks that each run `func(*args, **kwargs)`.

        Returns:
            Tuple containing the IDs of the started tasks.
        """
        # A generator expression containing `await` is an *asynchronous* generator, which `tuple()` can not
        # iterate over; a list comprehension is evaluated right here inside the coroutine instead.
        return tuple([await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)])

    @staticmethod
    def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
        """
        Takes the next element from `args_iter` and calls `func` with it.

        Depending on `arg_stars` the element is passed as is (0), unpacked with `*` (1) or unpacked with `**` (2).

        Returns:
            The return value of the `func` call, or `None` if the iterator is exhausted.

        Raises:
            ValueError: if `arg_stars` is not 0, 1 or 2.
        """
        try:
            arg = next(args_iter)
        except StopIteration:
            return
        if arg_stars == 0:
            return func(arg)
        if arg_stars == 1:
            return func(*arg)
        if arg_stars == 2:
            return func(**arg)
        raise ValueError(f"`arg_stars` must be 0, 1 or 2; got {arg_stars!r}")

    async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
                   end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """
        Starts up to `num_tasks` tasks running `func` with arguments taken from `args_iter`.

        Whenever one of these tasks ends, the next one is started with the following element of `args_iter`
        until the iterable is exhausted. The "all tasks known" flag is cleared at the beginning and only
        set again once the end of `args_iter` has been reached.
        """
        if self._all_tasks_known_flag.is_set():
            self._all_tasks_known_flag.clear()
        args_iter = iter(args_iter)

        async def _start_next_coroutine() -> bool:
            """Starts the task for the next arguments; returns `True` if there were none left instead."""
            cor = self._get_next_coroutine(func, args_iter, arg_stars)
            if cor is None:
                self._all_tasks_known_flag.set()
                return True
            await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
            return False

        async def _start_next(task_id: int) -> None:
            """End-callback of every mapped task: starts the next task, then runs the custom `end_callback`."""
            await _start_next_coroutine()
            await _execute_function(end_callback, args=(task_id, ))

        for _ in range(num_tasks):
            reached_end = await _start_next_coroutine()
            if reached_end:
                break

    async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
                  end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """Like `_map`, calling `func` with each element of `args_iter` as its only positional argument."""
        await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
                        end_callback=end_callback, cancel_callback=cancel_callback)

    async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
                      end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """Like `map`, but each element of `args_iter` is unpacked into positional arguments for `func`."""
        await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
                        end_callback=end_callback, cancel_callback=cancel_callback)

    async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
                            end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
        """Like `map`, but each element of `kwargs_iter` is unpacked into keyword arguments for `func`."""
        await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
                        end_callback=end_callback, cancel_callback=cancel_callback)
class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None:
self._func: CoroutineFunc = func
self._args: Iterable[Any] = args
self._kwargs: Mapping[str, Any] = kwargs if kwargs is not None else {}
self._final_callback: FinalCallbackT = final_callback
self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
self._end_callback: EndCallbackT = end_callback
self._cancel_callback: CancelCallbackT = cancel_callback
self._tasks: List[Task] = []
super().__init__(name=name)
@property
def func_name(self) -> str:
@ -26,28 +274,34 @@ class TaskPool:
@property
def size(self) -> int:
return len(self._tasks)
return self.num_running
def __repr__(self) -> str:
return f'<{self.__class__.__name__} func={self.func_name} size={self.size}>'
async def _start_one(self) -> int:
return await self._start_task(self._func(*self._args, **self._kwargs),
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
def _task_name(self, i: int) -> str:
return f'{self.func_name}_pool_task_{i}'
async def start(self, num: int = 1) -> List[int]:
return [await self._start_one() for _ in range(num)]
def _start_one(self) -> None:
self._tasks.append(start_task(self._func(*self._args, **self._kwargs), self._task_name(self.size),
final_callback=self._final_callback, cancel_callback=self._cancel_callback))
def stop(self, num: int = 1) -> List[int]:
    """
    Cancels up to `num` running tasks, beginning with the one that was added to the running tasks last.

    Returns:
        List of the IDs of the tasks that were cancelled (empty, if `num` is less than 1 or nothing is running).
    """
    limit = min(num, self.size)
    latest_first = reversed(self._running)
    # `zip` stops as soon as the range is exhausted, so at most `limit` IDs are taken.
    ids = [task_id for _, task_id in zip(range(limit), latest_first)]
    self.cancel(*ids)
    return ids
def start(self, num: int = 1) -> None:
for _ in range(num):
self._start_one()
def stop_all(self) -> List[int]:
return self.stop(self.size)
def stop(self, num: int = 1) -> int:
if num < 1:
return 0
return sum(task.cancel() for task in reversed(self._tasks[-num:]))
async def gather(self, return_exceptions: bool = False):
results = await gather(*self._tasks, return_exceptions=return_exceptions)
self._tasks = []
return results
async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None:
    """
    Calls `func` with the given arguments, awaiting the call if `func` is a coroutine function.

    Does nothing at all, if `func` is not callable (e.g. `None`).
    """
    if not callable(func):
        return
    keyword_args = {} if kwargs is None else kwargs
    if iscoroutinefunction(func):
        await func(*args, **keyword_args)
        return
    func(*args, **keyword_args)

View File

@ -0,0 +1,130 @@
import logging
from abc import ABC, abstractmethod
from asyncio import AbstractServer
from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Tuple, Union, Optional
from . import constants
from .pool import SimpleTaskPool
from .client import ControlClient, UnixControlClient
log = logging.getLogger(__name__)
def tasks_str(num: int) -> str:
    """Returns the noun "task" in the grammatical number matching `num` (singular only for exactly 1)."""
    if num == 1:
        return "task"
    return "tasks"
def get_cmd_arg(msg: str) -> Union[Tuple[str, Optional[int]], Tuple[None, None]]:
    """
    Splits a message into a command and an optional integer argument.

    Everything before the first space (after stripping the message) is the command; the rest, if any,
    must be convertible to an integer.

    Returns:
        `(command, argument)`, with `argument` being `None` if the message contains no argument,
        or `(None, None)` if the argument is not a valid integer.
    """
    command, separator, raw_arg = msg.strip().partition(' ')
    if not separator:
        return command, None
    try:
        number = int(raw_arg)
    except ValueError:
        return None, None
    return command, number
class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool
    """
    Abstract base class for a server that allows clients to control a `SimpleTaskPool` with simple text commands.

    Subclasses define how the underlying server is created (`get_server_instance`) and what is cleaned up
    after it stopped (`final_callback`).
    """
    client_class = ControlClient

    @abstractmethod
    async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
        """Must create and return the server that calls `client_connected_cb` for every new connection."""
        raise NotImplementedError

    @abstractmethod
    def final_callback(self) -> None:
        """Called once after the server stopped serving, no matter why it stopped."""
        raise NotImplementedError

    def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
        """Stores the `pool` to control and the keyword arguments passed to `get_server_instance` later on."""
        self._pool: SimpleTaskPool = pool
        self._server_kwargs = server_kwargs
        self._server: Optional[AbstractServer] = None

    async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
        """Starts `num` (default 1) tasks in the pool and writes the pool's response to the client."""
        if num is None:
            num = 1
        log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
        writer.write(str(await self._pool.start(num)).encode())

    def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
        """Stops `num` (default 1) tasks in the pool and writes the pool's response to the client."""
        if num is None:
            num = 1
        log.debug("%s requests stopping %s %s", self.client_class.__name__, num, tasks_str(num))
        # the requested number may be greater than the total number of running tasks
        writer.write(str(self._pool.stop(num)).encode())

    def _stop_all_tasks(self, writer: StreamWriter) -> None:
        """Stops all tasks in the pool and writes the pool's response to the client."""
        log.debug("%s requests stopping all tasks", self.client_class.__name__)
        writer.write(str(self._pool.stop_all()).encode())

    def _pool_size(self, writer: StreamWriter) -> None:
        """Writes the current size of the pool to the client."""
        log.debug("%s requests pool size", self.client_class.__name__)
        writer.write(str(self._pool.size).encode())

    def _pool_func(self, writer: StreamWriter) -> None:
        """Writes the name of the pool's function to the client."""
        log.debug("%s requests pool function", self.client_class.__name__)
        writer.write(self._pool.func_name.encode())

    async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None:
        """
        Receives and executes the commands of one client, answering each of them.

        Returns when the client disconnects (signalled by an empty message) or the server stops serving.
        """
        while self._server.is_serving():
            msg = (await reader.read(constants.MSG_BYTES)).decode().strip()
            if not msg:
                log.debug("%s disconnected", self.client_class.__name__)
                break
            cmd, arg = get_cmd_arg(msg)
            if cmd == constants.CMD_START:
                await self._start_tasks(writer, arg)
            elif cmd == constants.CMD_STOP:
                self._stop_tasks(writer, arg)
            elif cmd == constants.CMD_STOP_ALL:
                self._stop_all_tasks(writer)
            elif cmd == constants.CMD_SIZE:
                self._pool_size(writer)
            elif cmd == constants.CMD_FUNC:
                self._pool_func(writer)
            else:
                log.debug("%s sent invalid command: %s", self.client_class.__name__, msg)
                writer.write(b"Invalid command!")
            await writer.drain()

    async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
        """
        Handles one client connection: greets the client with the pool's name, then listens for its commands.

        The connection is closed from the server's side once listening ends.
        """
        log.debug("%s connected", self.client_class.__name__)
        try:
            writer.write(str(self._pool).encode())
            await writer.drain()
            await self._listen(reader, writer)
        finally:
            # The stream is not closed automatically when this callback returns.
            writer.close()

    async def _serve_forever(self) -> None:
        """
        Serves until cancelled and runs `final_callback` afterwards.

        The cancellation is not propagated, i.e. awaiting the task created by `serve_forever` after
        cancelling it does not raise `CancelledError`.
        """
        try:
            async with self._server:
                await self._server.serve_forever()
        except CancelledError:
            log.debug("%s stopped", self.__class__.__name__)
        finally:
            self.final_callback()

    async def serve_forever(self) -> Task:
        """
        Creates the server and starts serving in a new task.

        Returns:
            The task in which the server runs; cancel it to stop the server.
        """
        log.debug("Starting %s...", self.__class__.__name__)
        self._server = await self.get_server_instance(self._client_connected_cb, **self._server_kwargs)
        return create_task(self._serve_forever())
class UnixControlServer(ControlServer):
    """Control server that listens for clients on a unix socket."""
    client_class = UnixControlClient

    def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
        """
        Takes the mandatory `path` keyword argument as the socket path.

        All remaining keyword arguments are stored by the parent class and passed to `get_server_instance` later on.
        """
        self._socket_path = Path(server_kwargs.pop('path'))
        super().__init__(pool, **server_kwargs)

    async def get_server_instance(self, client_connected_cb, **kwargs) -> AbstractServer:
        """Starts a unix socket server on the stored socket path and returns it."""
        srv = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
        log.debug("Opened socket '%s'", str(self._socket_path))
        return srv

    def final_callback(self) -> None:
        """Removes the socket file, if it still exists."""
        # The file may already be gone at this point (newer asyncio versions can remove the socket themselves
        # when the server is closed); that must not turn the shutdown into an error.
        self._socket_path.unlink(missing_ok=True)
        log.debug("Removed socket '%s'", str(self._socket_path))

View File

@ -1,30 +0,0 @@
import logging
from asyncio.exceptions import CancelledError
from asyncio.tasks import Task, create_task
from typing import Awaitable, Any
from .types import FinalCallbackT, CancelCallbackT
log = logging.getLogger(__name__)
async def wrap(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None,
               cancel_callback: CancelCallbackT = None) -> Any:
    """
    Awaits the `awaitable`, logging its start, cancellation and exit under the given `task_name`.

    Returns whatever the `awaitable` returns. A cancellation is not propagated: `cancel_callback`
    (if callable) is called without arguments and `None` is returned in that case.
    `final_callback` (if callable) is called without arguments in any case, just before exiting.
    """
    log.info("Started %s", task_name)
    try:
        return await awaitable
    except CancelledError:
        log.info("Cancelling %s ...", task_name)
        if callable(cancel_callback):
            cancel_callback()
        log.info("Cancelled %s", task_name)
    finally:
        # Runs after a normal return as well as after a cancellation or any other exception.
        if callable(final_callback):
            final_callback()
        log.info("Exiting %s", task_name)
def start_task(awaitable: Awaitable, task_name: str, final_callback: FinalCallbackT = None,
               cancel_callback: CancelCallbackT = None) -> Task:
    """Wraps the `awaitable` with `wrap` and schedules it as a new task named `task_name`."""
    wrapped = wrap(awaitable, task_name, final_callback, cancel_callback)
    return create_task(wrapped, name=task_name)

View File

@ -1,6 +1,11 @@
from typing import Callable, Awaitable, Any
from asyncio.streams import StreamReader, StreamWriter
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
# Positional and keyword arguments that are unpacked into a function call.
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]

# A function taking any arguments and returning an awaitable.
# The bare `...` is the documented way to spell "any parameters"; a list containing `...` is not.
CoroutineFunc = Callable[..., Awaitable[Any]]
# Callbacks are not restricted any further than having to be callable.
FinalCallbackT = Callable
EndCallbackT = Callable
CancelCallbackT = Callable

# What opening a client connection yields: a reader/writer pair, or a pair of `None` if it failed.
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

105
tests/test_pool.py Normal file
View File

@ -0,0 +1,105 @@
import asyncio
from unittest import TestCase
from unittest.mock import PropertyMock, patch
from asyncio_taskpool import pool
EMPTY_LIST, EMPTY_DICT = [], {}
class BaseTaskPoolTestCase(TestCase):
    """Unit tests for the initialization and the simple properties/methods of `pool.BaseTaskPool`."""

    def setUp(self) -> None:
        self._pools = getattr(pool.BaseTaskPool, '_pools')
        # Remember the contents of the class-wide pools list to be able to restore them after the test.
        self._pools_before = list(self._pools)

        # These three methods are called during initialization, so we mock them by default during setup
        self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
        self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
        self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
        self.mock__add_pool = self._add_pool_patcher.start()
        self.mock_pool_size = self.pool_size_patcher.start()
        self.mock___str__ = self.__str___patcher.start()
        self.mock__add_pool.return_value = self.mock_idx = 123
        self.mock___str__.return_value = self.mock_str = 'foobar'

        # Test pool parameters:
        self.test_pool_size, self.test_pool_name = 420, 'test123'
        self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)

    def tearDown(self) -> None:
        # The pools list is a single object shared by all pool classes and is modified in place by `_add_pool`;
        # re-assigning the same object would undo nothing, so its contents are reset instead.
        self._pools[:] = self._pools_before
        self._add_pool_patcher.stop()
        self.pool_size_patcher.stop()
        self.__str___patcher.stop()

    def test__add_pool(self):
        self.assertListEqual(EMPTY_LIST, self._pools)
        self._add_pool_patcher.stop()
        output = pool.TaskPool._add_pool(self.task_pool)
        self.assertEqual(0, output)
        self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))

    def test_init(self):
        self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
        self.assertTrue(self.task_pool._open)
        self.assertEqual(0, self.task_pool._counter)
        self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
        self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
        self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
        self.assertEqual(0, self.task_pool._num_cancelled)
        self.assertEqual(0, self.task_pool._num_ended)
        self.assertEqual(self.mock_idx, self.task_pool._idx)
        self.assertEqual(self.test_pool_name, self.task_pool._name)
        self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
        self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
        self.mock__add_pool.assert_called_once_with(self.task_pool)
        self.mock_pool_size.assert_called_once_with(self.test_pool_size)
        self.mock___str__.assert_called_once_with()

    def test___str__(self):
        self.__str___patcher.stop()
        expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
        self.assertEqual(expected_str, str(self.task_pool))
        setattr(self.task_pool, '_name', None)
        expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
        self.assertEqual(expected_str, str(self.task_pool))

    def test_pool_size(self):
        self.pool_size_patcher.stop()
        self.task_pool._pool_size = self.test_pool_size
        self.assertEqual(self.test_pool_size, self.task_pool.pool_size)

        with self.assertRaises(ValueError):
            self.task_pool.pool_size = -1

        self.task_pool.pool_size = new_size = 69
        self.assertEqual(new_size, self.task_pool._pool_size)

    def test_is_open(self):
        self.task_pool._open = foo = 'foo'
        self.assertEqual(foo, self.task_pool.is_open)

    def test_num_running(self):
        self.task_pool._running = ['foo', 'bar', 'baz']
        self.assertEqual(3, self.task_pool.num_running)

    def test_num_cancelled(self):
        # The property must return exactly the value of the underlying counter.
        self.task_pool._num_cancelled = 3
        self.assertEqual(3, self.task_pool.num_cancelled)

    def test_num_ended(self):
        self.task_pool._num_ended = 3
        self.assertEqual(3, self.task_pool.num_ended)

    def test_num_finished(self):
        self.task_pool._num_cancelled = cancelled = 69
        self.task_pool._num_ended = ended = 420
        self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'}
        self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished)

    def test_is_full(self):
        self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)

    def test__task_name(self):
        i = 123
        self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))

82
usage/USAGE.md Normal file
View File

@ -0,0 +1,82 @@
# Using `asyncio-taskpool`
## Minimal example for `SimpleTaskPool`
The minimum required setup consists of a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool` and starts/stops the tasks as desired, eventually awaiting them all.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
```python
import logging
import asyncio
from asyncio_taskpool.pool import SimpleTaskPool
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger('asyncio_taskpool').addHandler(logging.StreamHandler())
async def work(n: int) -> None:
"""
Pseudo-worker function.
Counts up to an integer with a second of sleep before each iteration.
In a real-world use case, a worker function should probably have access
to some synchronisation primitive or shared resource to distribute work
between an arbitrary number of workers.
"""
for i in range(n):
await asyncio.sleep(1)
print("did", i)
async def main() -> None:
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
await pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit
await pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2
pool.close() # required for the last line
await pool.gather() # awaits all tasks, then flushes the pool
if __name__ == '__main__':
asyncio.run(main())
```
### Output
```
SimpleTaskPool-0 initialized
Started SimpleTaskPool-0_Task-0
Started SimpleTaskPool-0_Task-1
Started SimpleTaskPool-0_Task-2
did 0
did 0
did 0
Started SimpleTaskPool-0_Task-3
did 1
did 1
did 1
did 0
SimpleTaskPool-0 is closed!
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
Cancelling SimpleTaskPool-0_Task-2 ...
Cancelled SimpleTaskPool-0_Task-2
Ended SimpleTaskPool-0_Task-2
did 2
did 2
did 3
did 3
Ended SimpleTaskPool-0_Task-0
Ended SimpleTaskPool-0_Task-1
did 4
did 4
```
## Advanced example
...

0
usage/__init__.py Normal file
View File

65
usage/example_server.py Normal file
View File

@ -0,0 +1,65 @@
import asyncio
import logging
from asyncio_taskpool import SimpleTaskPool, UnixControlServer
from asyncio_taskpool.constants import PACKAGE_NAME
logging.getLogger().setLevel(logging.NOTSET)
logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
async def work(item: int) -> None:
    """
    Pretends to work on the `item` for one second and then prints it.

    The non-blocking sleep simulates something like an I/O operation that can be done asynchronously.
    """
    await asyncio.sleep(1)
    print("worked on", item)
async def worker(q: asyncio.Queue) -> None:
    """
    Simulates doing asynchronous work that takes a little bit of time to finish.

    Takes items from the queue `q` and works on them one after the other, until the worker's task is cancelled.
    An item the worker was still busy with at that moment is put back into the queue.
    """
    # We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
    while True:
        # We want to block here, until we can get the next item from the queue.
        item = await q.get()
        # Since we want a nice cleanup upon cancellation, we put the "work" to be done in a `try:` block.
        try:
            await work(item)
        except asyncio.CancelledError:
            # If the task gets cancelled before our current "work" item is finished, we put it back into the queue
            # because a worker must assume that some other worker can and will eventually finish the work on that item.
            q.put_nowait(item)
            # This takes us out of the loop. To enable cleanup we must re-raise the exception.
            raise
        finally:
            # Since putting an item into the queue (even if it has just been taken out), increments the internal
            # `._unfinished_tasks` counter in the queue, we must ensure that it is decremented before we end the
            # iteration or leave the loop. Otherwise, the queue's `.join()` will block indefinitely.
            q.task_done()
async def main() -> None:
    """
    Runs the example: fills a queue with work items, starts a pool of workers on it and opens a control server.

    Blocks until every item in the queue has been worked on, then shuts down the control server and the pool.
    """
    # First, we set up a queue of items that our workers can "work" on.
    q = asyncio.Queue()
    # We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit.
    for item in range(100):
        q.put_nowait(item)
    pool = SimpleTaskPool(worker, (q,))  # initializes the pool
    await pool.start(3)  # launches three worker tasks
    control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
    # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
    await q.join()
    # Since we don't need any "work" done anymore, we can close our control server by cancelling the task.
    control_server_task.cancel()
    # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
    # we can now safely cancel their tasks.
    pool.stop_all()
    # The pool must be closed before its tasks can be gathered.
    pool.close()
    # Finally, we allow all tasks to do their cleanup, if they need to do any, upon being cancelled.
    # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
    # we just silently collect their exceptions along with their return values.
    await pool.gather(return_exceptions=True)
    # Wait for the control server task to finish its shutdown as well.
    await control_server_task
if __name__ == '__main__':
asyncio.run(main())