refactored TaskPool._map; fixed TaskPool.apply; helper functions in separate module; new exception class

This commit is contained in:
Daniil Fajnberg 2022-02-07 17:31:43 +01:00
parent ac903d9be7
commit 4ea815be65
6 changed files with 101 additions and 53 deletions

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.0.2
version = 0.0.3
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -24,3 +24,7 @@ class InvalidTaskID(PoolException):
class PoolStillOpen(PoolException):
pass
class NotCoroutine(PoolException):
pass

View File

@ -0,0 +1,24 @@
from asyncio.coroutines import iscoroutinefunction
from typing import Any, Optional
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
if not callable(function):
return
if kwargs is None:
kwargs = {}
if iscoroutinefunction(function):
return await function(*args, **kwargs)
return function(*args, **kwargs)
def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
if arg_stars == 0:
return function(arg)
if arg_stars == 1:
return function(*arg)
if arg_stars == 2:
return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")

View File

@ -1,13 +1,15 @@
import logging
from asyncio import gather
from asyncio.coroutines import iscoroutinefunction
from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task
from functools import partial
from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Awaitable, Dict, Iterable, Iterator, List
from . import exceptions
from .helpers import execute_optional, star_function
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
@ -135,7 +137,7 @@ class BaseTaskPool:
self._cancelled[task_id] = self._running.pop(task_id)
self._num_cancelled += 1
log.debug("Cancelled %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, ))
await execute_optional(custom_callback, args=(task_id,))
async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
"""
@ -157,7 +159,7 @@ class BaseTaskPool:
self._num_ended += 1
self._enough_room.release()
log.info("Ended %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, ))
await execute_optional(custom_callback, args=(task_id,))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any:
@ -207,6 +209,8 @@ class BaseTaskPool:
Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
"""
if not iscoroutine(awaitable):
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
await self._enough_room.acquire()
@ -342,45 +346,47 @@ class TaskPool(BaseTaskPool):
return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]:
ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
assert isinstance(ids, list)
return ids
@staticmethod
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
async def _next_callback(self, task_id: int, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
reached_end = await self._start_next_task(func, args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback)
if reached_end:
self._all_tasks_known_flag.set()
await execute_optional(end_callback, args=(task_id,))
async def _start_next_task(self, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> bool:
if self._interrupt_flag.is_set():
return True
try:
arg = next(args_iter)
await self._start_task(
star_function(func, next(args_iter), arg_stars=arg_stars),
ignore_closed=True,
end_callback=partial(TaskPool._next_callback, self, func=func, args_iter=args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback),
cancel_callback=cancel_callback
)
except StopIteration:
return
if arg_stars == 0:
return func(arg)
if arg_stars == 1:
return func(*arg)
if arg_stars == 2:
return func(**arg)
raise ValueError
return True
return False
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
if self._all_tasks_known_flag.is_set():
self._all_tasks_known_flag.clear()
args_iter = iter(args_iter)
async def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None or self._interrupt_flag.is_set():
self._all_tasks_known_flag.set()
return True
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
return False
async def _start_next(task_id: int) -> None:
await _start_next_coroutine()
await _execute_function(end_callback, args=(task_id, ))
for _ in range(num_tasks):
reached_end = await _start_next_coroutine()
reached_end = await self._start_next_task(func, args_iter, arg_stars, end_callback, cancel_callback)
if reached_end:
self._all_tasks_known_flag.set()
break
async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
@ -403,6 +409,8 @@ class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None:
if not iscoroutinefunction(func):
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
self._func: CoroutineFunc = func
self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
@ -437,13 +445,3 @@ class SimpleTaskPool(BaseTaskPool):
def stop_all(self) -> List[int]:
return self.stop(self.size)
async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None:
if kwargs is None:
kwargs = {}
if callable(func):
if iscoroutinefunction(func):
await func(*args, **kwargs)
else:
func(*args, **kwargs)

View File

@ -1,10 +1,15 @@
from asyncio.streams import StreamReader, StreamWriter
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T')
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable
CancelCallbackT = Callable

View File

@ -123,9 +123,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@patch.object(pool, '_execute_function')
@patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task
@ -134,11 +134,11 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool, '_execute_function')
@patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
@ -151,9 +151,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock__task_name.reset_mock()
mock__execute_function.reset_mock()
mock_execute_optional.reset_mock()
# End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
@ -163,7 +163,7 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool.BaseTaskPool, '_task_ending')
@patch.object(pool.BaseTaskPool, '_task_cancellation')
@ -213,14 +213,27 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_not_called()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
mock_is_open.return_value = ignore_closed = False
mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
@ -231,8 +244,10 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
reset_mocks()
ignore_closed = True
mock_awaitable = mock_coroutine()
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
@ -246,11 +261,13 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
mock_awaitable = mock_coroutine()
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
await mock_awaitable
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)