diff --git a/setup.cfg b/setup.cfg index 97fce9f..606903e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.0.2 +version = 0.0.3 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/exceptions.py b/src/asyncio_taskpool/exceptions.py index 0b25408..8176e4e 100644 --- a/src/asyncio_taskpool/exceptions.py +++ b/src/asyncio_taskpool/exceptions.py @@ -24,3 +24,7 @@ class InvalidTaskID(PoolException): class PoolStillOpen(PoolException): pass + + +class NotCoroutine(PoolException): + pass diff --git a/src/asyncio_taskpool/helpers.py b/src/asyncio_taskpool/helpers.py new file mode 100644 index 0000000..d009a1f --- /dev/null +++ b/src/asyncio_taskpool/helpers.py @@ -0,0 +1,24 @@ +from asyncio.coroutines import iscoroutinefunction +from typing import Any, Optional + +from .types import T, AnyCallableT, ArgsT, KwArgsT + + +async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]: + if not callable(function): + return + if kwargs is None: + kwargs = {} + if iscoroutinefunction(function): + return await function(*args, **kwargs) + return function(*args, **kwargs) + + +def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T: + if arg_stars == 0: + return function(arg) + if arg_stars == 1: + return function(*arg) + if arg_stars == 2: + return function(**arg) + raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.") diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 85d6c17..f54d989 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -1,13 +1,15 @@ import logging from asyncio import gather -from asyncio.coroutines import iscoroutinefunction +from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.exceptions import CancelledError from asyncio.locks import Event, Semaphore from asyncio.tasks import Task, create_task +from functools import partial from math import inf -from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Awaitable, Dict, Iterable, Iterator, List from . import exceptions +from .helpers import execute_optional, star_function from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT @@ -135,7 +137,7 @@ class BaseTaskPool: self._cancelled[task_id] = self._running.pop(task_id) self._num_cancelled += 1 log.debug("Cancelled %s", self._task_name(task_id)) - await _execute_function(custom_callback, args=(task_id, )) + await execute_optional(custom_callback, args=(task_id,)) async def _task_ending(self, task_id: int, custom_callback: EndCallbackT = None) -> None: """ @@ -157,7 +159,7 @@ class BaseTaskPool: self._num_ended += 1 self._enough_room.release() log.info("Ended %s", self._task_name(task_id)) - await _execute_function(custom_callback, args=(task_id, )) + await execute_optional(custom_callback, args=(task_id,)) async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Any: @@ -207,6 +209,8 @@ class BaseTaskPool: Raises: `asyncio_taskpool.exceptions.PoolIsClosed` if the pool has been closed and `ignore_closed` is `False`. """ + if not iscoroutine(awaitable): + raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") if not (self.is_open or ignore_closed): raise exceptions.PoolIsClosed("Cannot start new tasks") await self._enough_room.acquire() @@ -342,45 +346,47 @@ class TaskPool(BaseTaskPool): return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback) async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, - end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]: - return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num)) + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> List[int]: + ids = await gather(*(self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))) + # TODO: for some reason PyCharm wrongly claims that `gather` returns a tuple of exceptions + assert isinstance(ids, list) + return ids - @staticmethod - def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]: + async def _next_callback(self, task_id: int, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + reached_end = await self._start_next_task(func, args_iter, arg_stars=arg_stars, + end_callback=end_callback, cancel_callback=cancel_callback) + if reached_end: + self._all_tasks_known_flag.set() + await execute_optional(end_callback, args=(task_id,)) + + async def _start_next_task(self, func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0, + end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> bool: + if self._interrupt_flag.is_set(): + return True try: - arg = next(args_iter) + await self._start_task( + star_function(func, next(args_iter), arg_stars=arg_stars), + ignore_closed=True, + end_callback=partial(TaskPool._next_callback, self, func=func, args_iter=args_iter, arg_stars=arg_stars, + end_callback=end_callback, cancel_callback=cancel_callback), + cancel_callback=cancel_callback + ) except StopIteration: - return - if arg_stars == 0: - return func(arg) - if arg_stars == 1: - return func(*arg) - if arg_stars == 2: - return func(**arg) - raise ValueError + return True + return False async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: - + if not self.is_open: + raise exceptions.PoolIsClosed("Cannot start new tasks") if self._all_tasks_known_flag.is_set(): self._all_tasks_known_flag.clear() args_iter = iter(args_iter) - - async def _start_next_coroutine() -> bool: - cor = self._get_next_coroutine(func, args_iter, arg_stars) - if cor is None or self._interrupt_flag.is_set(): - self._all_tasks_known_flag.set() - return True - await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback) - return False - - async def _start_next(task_id: int) -> None: - await _start_next_coroutine() - await _execute_function(end_callback, args=(task_id, )) - for _ in range(num_tasks): - reached_end = await _start_next_coroutine() + reached_end = await self._start_next_task(func, args_iter, arg_stars, end_callback, cancel_callback) if reached_end: + self._all_tasks_known_flag.set() break async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1, @@ -403,6 +409,8 @@ class SimpleTaskPool(BaseTaskPool): def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, name: str = None) -> None: + if not iscoroutinefunction(func): + raise exceptions.NotCoroutine(f"Not a coroutine function: {func}") self._func: CoroutineFunc = func self._args: ArgsT = args self._kwargs: KwArgsT = kwargs if kwargs is not None else {} @@ -437,13 +445,3 @@ class SimpleTaskPool(BaseTaskPool): def stop_all(self) -> List[int]: return self.stop(self.size) - - -async def _execute_function(func: Callable, args: ArgsT = (), kwargs: KwArgsT = None) -> None: - if kwargs is None: - kwargs = {} - if callable(func): - if iscoroutinefunction(func): - await func(*args, **kwargs) - else: - func(*args, **kwargs) diff --git a/src/asyncio_taskpool/types.py b/src/asyncio_taskpool/types.py index 649e58a..f0278c0 100644 --- a/src/asyncio_taskpool/types.py +++ b/src/asyncio_taskpool/types.py @@ -1,10 +1,15 @@ from asyncio.streams import StreamReader, StreamWriter -from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union +from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union +T = TypeVar('T') + ArgsT = Iterable[Any] KwArgsT = Mapping[str, Any] + +AnyCallableT = Callable[[...], Union[Awaitable[T], T]] CoroutineFunc = Callable[[...], Awaitable[Any]] + EndCallbackT = Callable CancelCallbackT = Callable diff --git a/tests/test_pool.py b/tests/test_pool.py index bf58573..9bfa20e 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -123,9 +123,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): i = 123 self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) - @patch.object(pool, '_execute_function') + @patch.object(pool, 'execute_optional') @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) - async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock): + async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() self.task_pool._num_cancelled = cancelled = 3 self.task_pool._running[task_id] = mock_task @@ -134,11 +134,11 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): self.assertEqual(mock_task, self.task_pool._cancelled[task_id]) self.assertEqual(cancelled + 1, self.task_pool._num_cancelled) mock__task_name.assert_called_with(task_id) - mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) + mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) - @patch.object(pool, '_execute_function') + @patch.object(pool, 'execute_optional') @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) - async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock): + async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() self.task_pool._num_ended = ended = 3 self.task_pool._enough_room._value = room = 123 @@ -151,9 +151,9 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): self.assertEqual(ended + 1, self.task_pool._num_ended) self.assertEqual(room + 1, self.task_pool._enough_room._value) mock__task_name.assert_called_with(task_id) - mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) + mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) mock__task_name.reset_mock() - mock__execute_function.reset_mock() + mock_execute_optional.reset_mock() # End cancelled task: self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id) @@ -163,7 +163,7 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): self.assertEqual(ended + 2, self.task_pool._num_ended) self.assertEqual(room + 2, self.task_pool._enough_room._value) mock__task_name.assert_called_with(task_id) - mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) + mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) @patch.object(pool.BaseTaskPool, '_task_ending') @patch.object(pool.BaseTaskPool, '_task_cancellation') @@ -213,14 +213,27 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): mock_create_task.return_value = mock_task = MagicMock() mock__task_wrapper.return_value = mock_wrapped = MagicMock() - mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock() + mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock() self.task_pool._counter = count = 123 self.task_pool._enough_room._value = room = 123 + with self.assertRaises(exceptions.NotCoroutine): + await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + self.assertEqual(count, self.task_pool._counter) + self.assertNotIn(count, self.task_pool._running) + self.assertEqual(room, self.task_pool._enough_room._value) + mock_is_open.assert_not_called() + mock__task_name.assert_not_called() + mock__task_wrapper.assert_not_called() + mock_create_task.assert_not_called() + reset_mocks() + mock_is_open.return_value = ignore_closed = False + mock_awaitable = mock_coroutine() with self.assertRaises(exceptions.PoolIsClosed): await self.task_pool._start_task(mock_awaitable, ignore_closed, end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + await mock_awaitable self.assertEqual(count, self.task_pool._counter) self.assertNotIn(count, self.task_pool._running) self.assertEqual(room, self.task_pool._enough_room._value) @@ -231,8 +244,10 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): reset_mocks() ignore_closed = True + mock_awaitable = mock_coroutine() output = await self.task_pool._start_task(mock_awaitable, ignore_closed, end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + await mock_awaitable self.assertEqual(count, output) self.assertEqual(count + 1, self.task_pool._counter) self.assertEqual(mock_task, self.task_pool._running[count]) @@ -246,11 +261,13 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): self.task_pool._enough_room._value = room del self.task_pool._running[count] + mock_awaitable = mock_coroutine() mock_create_task.side_effect = test_exception = TestException() with self.assertRaises(TestException) as e: await self.task_pool._start_task(mock_awaitable, ignore_closed, end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) self.assertEqual(test_exception, e) + await mock_awaitable self.assertEqual(count + 1, self.task_pool._counter) self.assertNotIn(count, self.task_pool._running) self.assertEqual(room, self.task_pool._enough_room._value)