Compare commits: 80fc91ec47 ... 56d38a3b44 (2 commits: 56d38a3b44, 0daed04167)
@@ -23,7 +23,7 @@ This module should **not** be considered part of the public API.
 
 from asyncio.streams import StreamReader, StreamWriter
 from pathlib import Path
-from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
+from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union
 
 
 T = TypeVar('T')
@@ -32,7 +32,7 @@ ArgsT = Iterable[Any]
 KwArgsT = Mapping[str, Any]
 
 AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
-CoroutineFunc = Callable[[...], Awaitable[Any]]
+CoroutineFunc = Callable[[...], Coroutine]
 
 EndCB = Callable
 CancelCB = Callable
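Note on the type change above: `Coroutine` is narrower than `Awaitable`. The code added further down closes coroutine objects that were created but never scheduled (`coroutine.close()`), which is only guaranteed to work on actual coroutine objects, not on arbitrary awaitables. A minimal, illustrative sketch (names here are not from this repository):

```python
from typing import Any, Awaitable, Callable, Coroutine

OldStyleFunc = Callable[..., Awaitable[Any]]  # any awaitable-returning callable matched
NewStyleFunc = Callable[..., Coroutine]       # only coroutine functions match

async def work(x: int) -> int:
    return x * 2

coro = work(21)  # creating the coroutine object does not run it
coro.close()     # a coroutine that was never awaited can be closed cleanly
```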
@@ -326,6 +326,8 @@ class BaseTaskPool:
         """
         self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
         await self._enough_room.acquire()
+        # TODO: Make sure that cancellation (group or pool) interrupts this method after context switching!
+        #  Possibly make use of the task group register for that.
         group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
         async with group_reg:
             task_id = self._num_started
@@ -528,6 +530,8 @@ class BaseTaskPool:
         self._tasks_cancelled.clear()
         self._tasks_running.clear()
         self._closed = True
+        # TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
+        #  await it to allow blocking until a closing command comes from a server.
 
 
 class TaskPool(BaseTaskPool):
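The new TODO above suggests replacing the boolean `_closed` flag with something awaitable. One possible shape for that idea, purely illustrative and not part of this commit:

```python
from asyncio import Event


class ClosableExample:
    """Illustrative only: an awaitable 'closed' state as hinted at by the TODO."""

    def __init__(self) -> None:
        self._closed = Event()

    def close(self) -> None:
        # Rather than flipping a boolean, set the event so that waiters wake up.
        self._closed.set()

    async def until_closed(self) -> None:
        # Blocks until `close()` has been called, e.g. via a command from a control server.
        await self._closed.wait()
```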
@@ -566,36 +570,50 @@ class TaskPool(BaseTaskPool):
                 return name
             i += 1
 
-    async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
-                         num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
+    async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
+                             num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
         """
-        Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
+        Creates coroutines with the supplied arguments and runs them as new tasks in the pool.
 
         This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
 
         Args:
             group_name:
-                Name of the task group to add the new task to.
+                Name of the task group to add the new tasks to.
             func:
-                The coroutine function to be run as a task within the task pool.
+                The coroutine function to be run in `num` tasks within the task pool.
             args (optional):
-                The positional arguments to pass into the function call.
+                The positional arguments to pass into each function call.
             kwargs (optional):
-                The keyword-arguments to pass into the function call.
+                The keyword-arguments to pass into each function call.
             num (optional):
                 The number of tasks to spawn with the specified parameters.
             end_callback (optional):
-                A callback to execute after the task has ended.
+                A callback to execute after each task has ended.
                 It is run with the task's ID as its only positional argument.
             cancel_callback (optional):
-                A callback to execute after cancellation of the task.
+                A callback to execute after cancellation of each task.
                 It is run with the task's ID as its only positional argument.
         """
         if kwargs is None:
             kwargs = {}
-        # TODO: Add exception logging
-        await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
-                                        cancel_callback=cancel_callback) for _ in range(num)))
+        for i in range(num):
+            try:
+                coroutine = func(*args, **kwargs)
+            except Exception as e:
+                # This means there was probably something wrong with the function arguments.
+                log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
+                              str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
+                continue
+            try:
+                await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
+                                       cancel_callback=cancel_callback)
+            except CancelledError:
+                # Either the task group or all tasks were cancelled, so this meta task is not supposed to spawn any
+                # more tasks and can return immediately.
+                log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
+                coroutine.close()
+                return
 
     def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None,
               end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
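Behavioral summary of the hunk above: instead of building all `num` coroutines up front and running the `_start_task` calls through a single `gather`, the spawner now creates and starts them one at a time, logs and skips coroutines whose creation fails, and, if cancelled while starting a task, closes the pending coroutine and stops spawning. A condensed, standalone sketch of that control flow (independent of the pool internals, names are illustrative):

```python
import logging
from asyncio import CancelledError

log = logging.getLogger(__name__)


async def spawn_n(start, func, args, kwargs, num):
    """Illustrative loop mirroring the spawner's error and cancellation handling."""
    for i in range(num):
        try:
            coroutine = func(*args, **kwargs)
        except Exception:
            # A bad argument set only costs this one task; the loop keeps going.
            log.exception("Could not create coroutine %s of %s", i + 1, num)
            continue
        try:
            await start(coroutine)  # stand-in for the pool's `_start_task(...)`
        except CancelledError:
            # Stop spawning and close the coroutine that never became a task.
            coroutine.close()
            return
```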
@@ -650,8 +668,8 @@ class TaskPool(BaseTaskPool):
             raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
         self._task_groups.setdefault(group_name, TaskGroupRegister())
         meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
-        meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
-                                                   end_callback=end_callback, cancel_callback=cancel_callback)))
+        meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num,
+                                                       end_callback=end_callback, cancel_callback=cancel_callback)))
         return group_name
 
     @staticmethod
@@ -690,27 +708,31 @@ class TaskPool(BaseTaskPool):
                 The callback that was specified to execute after cancellation of the task (and the next one).
                 It is run with the task's ID as its only positional argument.
         """
-        map_semaphore = Semaphore(num_concurrent)
-        release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
-        for next_arg in arg_iter:
-            # When the number of running tasks spawned by this method reaches the specified maximum,
-            # this next line will block, until one of them ends and releases the semaphore.
-            await map_semaphore.acquire()
+        semaphore = Semaphore(num_concurrent)
+        release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback)
+        for i, next_arg in enumerate(arg_iter):
+            semaphore_acquired = False
             try:
-                await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
-                                       ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
-            except CancelledError:
-                # This means that no more tasks are supposed to be created from this `arg_iter`;
-                # thus, we can forget about the rest of the arguments.
-                log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
-                map_semaphore.release()
-                return
+                coroutine = star_function(func, next_arg, arg_stars=arg_stars)
             except Exception as e:
-                # This means an exception occurred during task **creation**, meaning no task has been created.
-                # It does not imply an error within the task itself.
-                log.exception("%s occurred while trying to create task: %s(%s%s)",
-                              str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
-                map_semaphore.release()
+                # This means there was probably something wrong with the function arguments.
+                log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)",
+                              str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg))
+                continue
+            try:
+                # When the number of running tasks spawned by this method reaches the specified maximum,
+                # this next line will block, until one of them ends and releases the semaphore.
+                semaphore_acquired = await semaphore.acquire()
+                await self._start_task(coroutine, group_name=group_name, ignore_lock=True,
+                                       end_callback=release_cb, cancel_callback=cancel_callback)
+            except CancelledError:
+                # Either the task group or all tasks were cancelled, so this meta task is not supposed to spawn any
+                # more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.)
+                log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i)
+                coroutine.close()
+                if semaphore_acquired:
+                    semaphore.release()
+                return
 
     def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
              end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
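The consumer above follows the same pattern, with one extra subtlety: the semaphore is acquired inside the `try` block, and `semaphore_acquired` records whether the acquisition succeeded, so a cancellation hitting either `acquire()` or `_start_task()` releases the semaphore only if it was actually taken. A standalone sketch of that bounded-concurrency pattern (illustrative names; it assumes the started task eventually calls `semaphore.release()` when it finishes):

```python
from asyncio import CancelledError, Semaphore


async def consume(arg_iter, make_coro, start, num_concurrent):
    """Illustrative: start at most `num_concurrent` tasks at a time from `arg_iter`."""
    semaphore = Semaphore(num_concurrent)
    for arg in arg_iter:
        acquired = False
        try:
            coroutine = make_coro(arg)
        except Exception:
            continue  # argument error: skip this item, keep consuming
        try:
            acquired = await semaphore.acquire()  # blocks while the limit is reached
            await start(coroutine, on_end=semaphore.release)
        except CancelledError:
            coroutine.close()
            if acquired:
                semaphore.release()
            return
```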
@@ -924,6 +946,7 @@ class SimpleTaskPool(BaseTaskPool):
                              end_callback=self._end_callback, cancel_callback=self._cancel_callback)
             for _ in range(num)
         )
+        # TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
         await gather(*start_coroutines)
 
     def start(self, num: int) -> str:
@@ -455,49 +455,65 @@ class TaskPoolTestCase(CommonTestCase):
         self.assertEqual(expected_output, output)
 
     @patch.object(pool.TaskPool, '_start_task')
-    async def test__apply_num(self, mock__start_task: AsyncMock):
-        group_name = FOO + BAR
-        mock_awaitable = object()
-        mock_func = MagicMock(return_value=mock_awaitable)
-        args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
+    async def test__apply_spawner(self, mock__start_task: AsyncMock):
+        grp_name = FOO + BAR
+        mock_awaitable1, mock_awaitable2 = object(), object()
+        mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
+        args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
         end_cb, cancel_cb = MagicMock(), MagicMock()
-        self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
-        mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
-        mock__start_task.assert_has_awaits(3 * [
-            call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
+        self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
+        mock_func.assert_has_calls(num * [call(*args, **kw)])
+        mock__start_task.assert_has_awaits([
+            call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
+            call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
         ])
 
-        mock_func.reset_mock()
+        mock_func.reset_mock(side_effect=True)
        mock__start_task.reset_mock()
 
-        self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
-        mock_func.assert_has_calls(num * [call(*args)])
-        mock__start_task.assert_has_awaits(num * [
-            call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
+        # Simulate cancellation while the second task is being started.
+        mock__start_task.side_effect = [None, CancelledError, None]
+        mock_coroutine_to_close = MagicMock()
+        mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
+        self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
+        mock_func.assert_has_calls(2 * [call(*args)])
+        mock__start_task.assert_has_awaits([
+            call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
+            call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
         ])
+        mock_coroutine_to_close.close.assert_called_once_with()
 
     @patch.object(pool, 'create_task')
-    @patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
+    @patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
     @patch.object(pool, 'TaskGroupRegister')
     @patch.object(pool.TaskPool, '_generate_group_name')
     @patch.object(pool.BaseTaskPool, '_check_start')
     def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
-                   mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
+                   mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
         mock__generate_group_name.return_value = generated_name = 'name 123'
         mock_group_reg = set_up_mock_group_register(mock_reg_cls)
-        mock__apply_num.return_value = mock_apply_coroutine = object()
+        mock__apply_spawner.return_value = mock_apply_coroutine = object()
         mock_create_task.return_value = fake_task = object()
         mock_func, num, group_name = MagicMock(), 3, FOO + BAR
         args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
         end_cb, cancel_cb = MagicMock(), MagicMock()
 
         self.task_pool._task_groups = {group_name: 'causes error'}
         with self.assertRaises(exceptions.InvalidGroupName):
             self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
         mock__check_start.assert_called_once_with(function=mock_func)
+        mock__apply_spawner.assert_not_called()
         mock_create_task.assert_not_called()
 
         mock__check_start.reset_mock()
         self.task_pool._task_groups = {}
 
         def check_assertions(_group_name, _output):
             self.assertEqual(_group_name, _output)
             mock__check_start.assert_called_once_with(function=mock_func)
             self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
-            mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
-                                                    end_callback=end_cb, cancel_callback=cancel_cb)
+            mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
+                                                        end_callback=end_cb, cancel_callback=cancel_cb)
             mock_create_task.assert_called_once_with(mock_apply_coroutine)
             self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
 
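The rewritten test above drives `mock_func` and `mock__start_task` with `side_effect` lists, so that the second coroutine "creation" raises and the error branch of the spawner is exercised; `reset_mock(side_effect=True)` is then required, since a plain `reset_mock()` keeps the configured side effect. A small self-contained illustration of that `unittest.mock` behaviour:

```python
from unittest.mock import MagicMock

m = MagicMock(side_effect=[1, ValueError("boom"), 3])
assert m() == 1        # items of the list are returned in order ...
try:
    m()                # ... unless an item is an exception, which is raised instead
except ValueError:
    pass
assert m() == 3

m.reset_mock()                   # clears call history but keeps the side effect
m.reset_mock(side_effect=True)   # clears the side effect as well
```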
@@ -507,7 +523,7 @@ class TaskPoolTestCase(CommonTestCase):
 
         mock__check_start.reset_mock()
         self.task_pool._task_groups.clear()
-        mock__apply_num.reset_mock()
+        mock__apply_spawner.reset_mock()
         mock_create_task.reset_mock()
 
         output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
@@ -532,20 +548,20 @@ class TaskPoolTestCase(CommonTestCase):
         n = 2
         mock_semaphore_cls.return_value = semaphore = Semaphore(n)
         mock__get_map_end_callback.return_value = map_cb = MagicMock()
-        awaitable = 'totally an awaitable'
-        mock_star_function.side_effect = [awaitable, Exception(), awaitable]
+        awaitable1, awaitable2 = 'totally an awaitable', object()
+        mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
         arg1, arg2, bad = 123456789, 'function argument', None
         args = [arg1, bad, arg2]
-        group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
+        grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
         end_cb, cancel_cb = MagicMock(), MagicMock()
-        self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
-        # We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then
-        # acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked.
+        self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
+        # We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it to be locked.
         self.assertTrue(semaphore.locked())
         mock_semaphore_cls.assert_called_once_with(n)
         mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
-        mock__start_task.assert_has_awaits(2 * [
-            call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
+        mock__start_task.assert_has_awaits([
+            call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
+            call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
         ])
         mock_star_function.assert_has_calls([
             call(mock_func, arg1, arg_stars=stars),
@@ -556,17 +572,50 @@ class TaskPoolTestCase(CommonTestCase):
         mock_semaphore_cls.reset_mock()
         mock__get_map_end_callback.reset_mock()
         mock__start_task.reset_mock()
-        mock_star_function.reset_mock()
+        mock_star_function.reset_mock(side_effect=True)
 
-        # With a CancelledError thrown while starting a task:
-        mock_semaphore_cls.return_value = semaphore = Semaphore(1)
-        mock_star_function.side_effect = CancelledError()
-        self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb))
-        self.assertFalse(semaphore.locked())
+        # With a CancelledError thrown while acquiring the semaphore:
+        mock_acquire = AsyncMock(side_effect=[True, CancelledError])
+        mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
+        mock_star_function.return_value = mock_coroutine = MagicMock()
+        arg_it = iter(arg for arg in (arg1, arg2, FOO))
+        self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
         mock_semaphore_cls.assert_called_once_with(n)
-        mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
-        mock__start_task.assert_not_called()
-        mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
+        mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
+        mock_star_function.assert_has_calls([
+            call(mock_func, arg1, arg_stars=stars),
+            call(mock_func, arg2, arg_stars=stars)
+        ])
+        mock_acquire.assert_has_awaits([call(), call()])
+        mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
+                                                  end_callback=map_cb, cancel_callback=cancel_cb)
+        mock_coroutine.close.assert_called_once_with()
+        mock_semaphore.release.assert_not_called()
+        self.assertEqual(FOO, next(arg_it))
+
+        mock_acquire.reset_mock(side_effect=True)
+        mock_semaphore_cls.reset_mock()
+        mock__get_map_end_callback.reset_mock()
+        mock__start_task.reset_mock()
+        mock_star_function.reset_mock(side_effect=True)
+
+        # With a CancelledError thrown while starting the task:
+        mock__start_task.side_effect = [None, CancelledError]
+        arg_it = iter(arg for arg in (arg1, arg2, FOO))
+        self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
+        mock_semaphore_cls.assert_called_once_with(n)
+        mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
+        mock_star_function.assert_has_calls([
+            call(mock_func, arg1, arg_stars=stars),
+            call(mock_func, arg2, arg_stars=stars)
+        ])
+        mock_acquire.assert_has_awaits([call(), call()])
+        mock__start_task.assert_has_awaits(2 * [
+            call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
+        ])
+        mock_coroutine.close.assert_called_once_with()
+        mock_semaphore.release.assert_called_once_with()
+        self.assertEqual(FOO, next(arg_it))
 
     @patch.object(pool, 'create_task')
     @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
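In the new cancellation cases above, the semaphore is replaced by a mock whose `acquire` is an `AsyncMock` with `side_effect=[True, CancelledError]`, so the second `await acquire()` raises and the consumer must close the leftover coroutine without consuming the rest of the iterator. A minimal illustration of awaiting an `AsyncMock` configured this way (requires Python 3.8+ for `AsyncMock`):

```python
import asyncio
from asyncio import CancelledError
from unittest.mock import AsyncMock, call


async def main():
    acquire = AsyncMock(side_effect=[True, CancelledError])
    assert await acquire() is True   # first await returns the value from the list
    try:
        await acquire()              # second await raises CancelledError
    except CancelledError:
        pass
    acquire.assert_has_awaits([call(), call()])


asyncio.run(main())
```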
@@ -695,6 +744,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
 
     def tearDown(self) -> None:
         self.base_class_init_patcher.stop()
+        super().tearDown()
 
     def test_init(self):
         self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)