refactored TaskPool._map; fixed TaskPool.apply; helper functions in separate module; new exception class

This commit is contained in:
Daniil Fajnberg 2022-02-07 17:31:43 +01:00
parent ac903d9be7
commit 4ea815be65
6 changed files with 101 additions and 53 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.0.2 version = 0.0.3
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -24,3 +24,7 @@ class InvalidTaskID(PoolException):
class PoolStillOpen(PoolException): class PoolStillOpen(PoolException):
pass pass
class NotCoroutine(PoolException):
pass

View File

@ -0,0 +1,24 @@
from asyncio.coroutines import iscoroutinefunction
from typing import Any, Optional
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
if not callable(function):
return
if kwargs is None:
kwargs = {}
if iscoroutinefunction(function):
return await function(*args, **kwargs)
return function(*args, **kwargs)
def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
if arg_stars == 0:
return function(arg)
if arg_stars == 1:
return function(*arg)
if arg_stars == 2:
return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")

View File

@ -1,13 +1,15 @@
import logging import logging
from asyncio import gather from asyncio import gather
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Event, Semaphore from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task from asyncio.tasks import Task, create_task
from functools import partial
from math import inf from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from typing import Any, Awaitable, Dict, Iterable, Iterator, List
from . import exceptions from . import exceptions
from .helpers import execute_optional, star_function
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
@ -135,7 +137,7 @@ class BaseTaskPool:
self._cancelled[task_id] = self._running.pop(task_id) self._cancelled[task_id] = self._running.pop(task_id)
self._num_cancelled += 1 self._num_cancelled += 1
log.debug("Cancelled %s", self._task_name(task_id)) log.debug("Cancelled %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, )) await execute_optional(custom_callback, args=(task_id,))
async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None: async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
""" """
@ -157,7 +159,7 @@ class BaseTaskPool:
self._num_ended += 1 self._num_ended += 1
self._enough_room.release() self._enough_room.release()
log.info("Ended %s", self._task_name(task_id)) log.info("Ended %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, )) await execute_optional(custom_callback, args=(task_id,))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None, async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any: cancel_callback: CancelCallbackT = None) -> Any:
@ -207,6 +209,8 @@ class BaseTaskPool:
Raises: Raises:
`asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`. `asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`.
""" """
if not iscoroutine(awaitable):
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if not (self.is_open or ignore_closed): if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsClosed("Cannot start new tasks")
await self._enough_room.acquire() await self._enough_room.acquire()
@ -342,45 +346,47 @@ class TaskPool(BaseTaskPool):
return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback) return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]:
return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)) ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)))
# TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions
assert isinstance(ids, list)
return ids
@staticmethod async def _next_callback(self, task_id: int, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
reached_end = await self._start_next_task(func, args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback)
if reached_end:
self._all_tasks_known_flag.set()
await execute_optional(end_callback, args=(task_id,))
async def _start_next_task(self, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> bool:
if self._interrupt_flag.is_set():
return True
try: try:
arg = next(args_iter) await self._start_task(
star_function(func, next(args_iter), arg_stars=arg_stars),
ignore_closed=True,
end_callback=partial(TaskPool._next_callback, self, func=func, args_iter=args_iter, arg_stars=arg_stars,
end_callback=end_callback, cancel_callback=cancel_callback),
cancel_callback=cancel_callback
)
except StopIteration: except StopIteration:
return return True
if arg_stars == 0: return False
return func(arg)
if arg_stars == 1:
return func(*arg)
if arg_stars == 2:
return func(**arg)
raise ValueError
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
if self._all_tasks_known_flag.is_set(): if self._all_tasks_known_flag.is_set():
self._all_tasks_known_flag.clear() self._all_tasks_known_flag.clear()
args_iter = iter(args_iter) args_iter = iter(args_iter)
async def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None or self._interrupt_flag.is_set():
self._all_tasks_known_flag.set()
return True
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
return False
async def _start_next(task_id: int) -> None:
await _start_next_coroutine()
await _execute_function(end_callback, args=(task_id, ))
for _ in range(num_tasks): for _ in range(num_tasks):
reached_end = await _start_next_coroutine() reached_end = await self._start_next_task(func, args_iter, arg_stars, end_callback, cancel_callback)
if reached_end: if reached_end:
self._all_tasks_known_flag.set()
break break
async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
@ -403,6 +409,8 @@ class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None: name: str = None) -> None:
if not iscoroutinefunction(func):
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
self._func: CoroutineFunc = func self._func: CoroutineFunc = func
self._args: ArgsT = args self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {} self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
@ -437,13 +445,3 @@ class SimpleTaskPool(BaseTaskPool):
def stop_all(self) -> List[int]: def stop_all(self) -> List[int]:
return self.stop(self.size) return self.stop(self.size)
async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None:
if kwargs is None:
kwargs = {}
if callable(func):
if iscoroutinefunction(func):
await func(*args, **kwargs)
else:
func(*args, **kwargs)

View File

@ -1,10 +1,15 @@
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T')
ArgsT = Iterable[Any] ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[Awaitable[T], T]]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable EndCallbackT = Callable
CancelCallbackT = Callable CancelCallbackT = Callable

View File

@ -123,9 +123,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
i = 123 i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@patch.object(pool, '_execute_function') @patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock): async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3 self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task self.task_pool._running[task_id] = mock_task
@ -134,11 +134,11 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(mock_task, self.task_pool._cancelled[task_id]) self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled) self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
mock__task_name.assert_called_with(task_id) mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool, '_execute_function') @patch.object(pool, 'execute_optional')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock): async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3 self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123 self.task_pool._enough_room._value = room = 123
@ -151,9 +151,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(ended + 1, self.task_pool._num_ended) self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value) self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id) mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock__task_name.reset_mock() mock__task_name.reset_mock()
mock__execute_function.reset_mock() mock_execute_optional.reset_mock()
# End cancelled task: # End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id) self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
@ -163,7 +163,7 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(ended + 2, self.task_pool._num_ended) self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value) self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id) mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool.BaseTaskPool, '_task_ending') @patch.object(pool.BaseTaskPool, '_task_ending')
@patch.object(pool.BaseTaskPool, '_task_cancellation') @patch.object(pool.BaseTaskPool, '_task_cancellation')
@ -213,14 +213,27 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
mock_create_task.return_value = mock_task = MagicMock() mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock() mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock() mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123 self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123 self.task_pool._enough_room._value = room = 123
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_not_called()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
mock_is_open.return_value = ignore_closed = False mock_is_open.return_value = ignore_closed = False
mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsClosed): with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._start_task(mock_awaitable, ignore_closed, await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, self.task_pool._counter) self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running) self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value) self.assertEqual(room, self.task_pool._enough_room._value)
@ -231,8 +244,10 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
reset_mocks() reset_mocks()
ignore_closed = True ignore_closed = True
mock_awaitable = mock_coroutine()
output = await self.task_pool._start_task(mock_awaitable, ignore_closed, output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, output) self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter) self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count]) self.assertEqual(mock_task, self.task_pool._running[count])
@ -246,11 +261,13 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool._enough_room._value = room self.task_pool._enough_room._value = room
del self.task_pool._running[count] del self.task_pool._running[count]
mock_awaitable = mock_coroutine()
mock_create_task.side_effect = test_exception = TestException() mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e: with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed, await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e) self.assertEqual(test_exception, e)
await mock_awaitable
self.assertEqual(count + 1, self.task_pool._counter) self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running) self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value) self.assertEqual(room, self.task_pool._enough_room._value)