generated from daniil-berg/boilerplate-py
Compare commits
3 Commits
d7cd16c540
...
e3bbb05eac
Author | SHA1 | Date | |
---|---|---|---|
e3bbb05eac | |||
36527ccffc | |||
d047b99119 |
@ -10,4 +10,3 @@ skip_covered = False
|
|||||||
exclude_lines =
|
exclude_lines =
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
if __name__ == ['"]__main__['"]:
|
if __name__ == ['"]__main__['"]:
|
||||||
if sys.version_info.+:
|
|
||||||
|
4
.github/workflows/main.yaml
vendored
4
.github/workflows/main.yaml
vendored
@ -1,5 +1,7 @@
|
|||||||
name: CI
|
name: CI
|
||||||
on: [push]
|
on:
|
||||||
|
push:
|
||||||
|
branches: [master]
|
||||||
jobs:
|
jobs:
|
||||||
tests:
|
tests:
|
||||||
name: Python ${{ matrix.python-version }} Tests
|
name: Python ${{ matrix.python-version }} Tests
|
||||||
|
@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
|
|||||||
author = 'Daniil Fajnberg'
|
author = 'Daniil Fajnberg'
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
release = '1.0.0-beta'
|
release = '1.0.2'
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 1.0.1
|
version = 1.0.2
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -21,8 +21,13 @@ This module should **not** be considered part of the public API.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
|
|
||||||
|
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
|
||||||
|
|
||||||
DEFAULT_TASK_GROUP = 'default'
|
DEFAULT_TASK_GROUP = 'default'
|
||||||
|
|
||||||
SESSION_MSG_BYTES = 1024 * 100
|
SESSION_MSG_BYTES = 1024 * 100
|
||||||
|
@ -20,12 +20,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
|||||||
|
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import sys
|
|
||||||
from asyncio.coroutines import iscoroutinefunction
|
from asyncio.coroutines import iscoroutinefunction
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from inspect import getdoc
|
from inspect import getdoc
|
||||||
from typing import Any, Callable, Optional, Type, Union
|
from typing import Any, Callable, Optional, Type, Union
|
||||||
|
|
||||||
|
from .constants import PYTHON_BEFORE_39
|
||||||
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
||||||
|
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ class ClassMethodWorkaround:
|
|||||||
|
|
||||||
|
|
||||||
# Starting with Python 3.9, this is thankfully no longer necessary.
|
# Starting with Python 3.9, this is thankfully no longer necessary.
|
||||||
if sys.version_info[:2] < (3, 9):
|
if PYTHON_BEFORE_39:
|
||||||
classmethod = ClassMethodWorkaround
|
classmethod = ClassMethodWorkaround
|
||||||
else:
|
else:
|
||||||
classmethod = builtins.classmethod
|
classmethod = builtins.classmethod
|
||||||
|
@ -28,6 +28,7 @@ For further details about the classes check their respective documentation.
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Semaphore
|
||||||
@ -37,7 +38,7 @@ from math import inf
|
|||||||
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
from .internals.constants import DEFAULT_TASK_GROUP
|
from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
|
||||||
from .internals.group_register import TaskGroupRegister
|
from .internals.group_register import TaskGroupRegister
|
||||||
from .internals.helpers import execute_optional, star_function
|
from .internals.helpers import execute_optional, star_function
|
||||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||||
@ -360,6 +361,23 @@ class BaseTaskPool:
|
|||||||
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
||||||
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Returns a dictionary to unpack in a `Task.cancel()` method.
|
||||||
|
|
||||||
|
This method exists to ensure proper compatibility with older Python versions.
|
||||||
|
If `msg` is `None`, an empty dictionary is returned.
|
||||||
|
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
|
||||||
|
Otherwise the keyword dictionary contains the `msg` parameter.
|
||||||
|
"""
|
||||||
|
if msg is None:
|
||||||
|
return {}
|
||||||
|
if PYTHON_BEFORE_39:
|
||||||
|
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
|
||||||
|
return {}
|
||||||
|
return {'msg': msg}
|
||||||
|
|
||||||
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels the tasks with the specified IDs.
|
Cancels the tasks with the specified IDs.
|
||||||
@ -378,8 +396,9 @@ class BaseTaskPool:
|
|||||||
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
||||||
"""
|
"""
|
||||||
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel(msg=msg)
|
task.cancel(**kw)
|
||||||
|
|
||||||
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
||||||
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
||||||
@ -392,7 +411,7 @@ class BaseTaskPool:
|
|||||||
self._meta_tasks_cancelled.update(meta_tasks)
|
self._meta_tasks_cancelled.update(meta_tasks)
|
||||||
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
||||||
|
|
||||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
|
||||||
"""
|
"""
|
||||||
Removes all tasks from the specified group and cancels them.
|
Removes all tasks from the specified group and cancels them.
|
||||||
|
|
||||||
@ -406,7 +425,7 @@ class BaseTaskPool:
|
|||||||
self._cancel_group_meta_tasks(group_name)
|
self._cancel_group_meta_tasks(group_name)
|
||||||
while group_reg:
|
while group_reg:
|
||||||
try:
|
try:
|
||||||
self._tasks_running[group_reg.pop()].cancel(msg=msg)
|
self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
continue
|
continue
|
||||||
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
||||||
@ -433,7 +452,8 @@ class BaseTaskPool:
|
|||||||
group_reg = self._task_groups.pop(group_name)
|
group_reg = self._task_groups.pop(group_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
kw = self._get_cancel_kw(msg)
|
||||||
|
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
||||||
log.debug("%s forgot task group %s", str(self), group_name)
|
log.debug("%s forgot task group %s", str(self), group_name)
|
||||||
|
|
||||||
def cancel_all(self, msg: str = None) -> None:
|
def cancel_all(self, msg: str = None) -> None:
|
||||||
@ -448,9 +468,10 @@ class BaseTaskPool:
|
|||||||
msg (optional): Passed to the `Task.cancel()` method of every task.
|
msg (optional): Passed to the `Task.cancel()` method of every task.
|
||||||
"""
|
"""
|
||||||
log.warning("%s cancelling all tasks!", str(self))
|
log.warning("%s cancelling all tasks!", str(self))
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
while self._task_groups:
|
while self._task_groups:
|
||||||
group_name, group_reg = self._task_groups.popitem()
|
group_name, group_reg = self._task_groups.popitem()
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
||||||
|
|
||||||
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
||||||
"""
|
"""
|
||||||
@ -604,7 +625,7 @@ class TaskPool(BaseTaskPool):
|
|||||||
# This means there was probably something wrong with the function arguments.
|
# This means there was probably something wrong with the function arguments.
|
||||||
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
||||||
continue
|
continue # TODO: Consider returning instead of continuing
|
||||||
try:
|
try:
|
||||||
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
||||||
cancel_callback=cancel_callback)
|
cancel_callback=cancel_callback)
|
||||||
@ -941,13 +962,24 @@ class SimpleTaskPool(BaseTaskPool):
|
|||||||
|
|
||||||
async def _start_num(self, num: int, group_name: str) -> None:
|
async def _start_num(self, num: int, group_name: str) -> None:
|
||||||
"""Starts `num` new tasks in group `group_name`."""
|
"""Starts `num` new tasks in group `group_name`."""
|
||||||
start_coroutines = (
|
for i in range(num):
|
||||||
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
|
try:
|
||||||
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
|
coroutine = self._func(*self._args, **self._kwargs)
|
||||||
for _ in range(num)
|
except Exception as e:
|
||||||
)
|
# This means there was probably something wrong with the function arguments.
|
||||||
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
|
log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
await gather(*start_coroutines)
|
str(e.__class__.__name__), str(self), self._func.__name__,
|
||||||
|
repr(self._args), repr(self._kwargs))
|
||||||
|
continue # TODO: Consider returning instead of continuing
|
||||||
|
try:
|
||||||
|
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
|
||||||
|
cancel_callback=self._cancel_callback)
|
||||||
|
except CancelledError:
|
||||||
|
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
|
||||||
|
# more tasks and can return immediately.
|
||||||
|
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
|
||||||
|
coroutine.close()
|
||||||
|
return
|
||||||
|
|
||||||
def start(self, num: int) -> str:
|
def start(self, num: int) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -18,10 +18,11 @@ __doc__ = """
|
|||||||
Unittests for the `asyncio_taskpool.helpers` module.
|
Unittests for the `asyncio_taskpool.helpers` module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
from unittest import IsolatedAsyncioTestCase, TestCase
|
from unittest import IsolatedAsyncioTestCase, TestCase
|
||||||
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool.internals import constants
|
||||||
from asyncio_taskpool.internals import helpers
|
from asyncio_taskpool.internals import helpers
|
||||||
|
|
||||||
|
|
||||||
@ -152,3 +153,15 @@ class ClassMethodWorkaroundTestCase(TestCase):
|
|||||||
cls = None
|
cls = None
|
||||||
output = instance.__get__(obj, cls)
|
output = instance.__get__(obj, cls)
|
||||||
self.assertEqual(expected_output, output)
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
def test_correct_class(self):
|
||||||
|
is_older_python = constants.PYTHON_BEFORE_39
|
||||||
|
try:
|
||||||
|
constants.PYTHON_BEFORE_39 = True
|
||||||
|
importlib.reload(helpers)
|
||||||
|
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
|
||||||
|
constants.PYTHON_BEFORE_39 = False
|
||||||
|
importlib.reload(helpers)
|
||||||
|
self.assertIs(classmethod, helpers.classmethod)
|
||||||
|
finally:
|
||||||
|
constants.PYTHON_BEFORE_39 = is_older_python
|
||||||
|
@ -311,13 +311,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_not_called()
|
mock__task_name.assert_not_called()
|
||||||
|
|
||||||
|
@patch('warnings.warn')
|
||||||
|
def test__get_cancel_kw(self, mock_warn: MagicMock):
|
||||||
|
msg = None
|
||||||
|
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
msg = 'something'
|
||||||
|
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
|
||||||
|
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_called_once()
|
||||||
|
mock_warn.reset_mock()
|
||||||
|
|
||||||
|
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
|
||||||
|
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
||||||
def test_cancel(self, mock__get_running_task: MagicMock):
|
def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
task_id1, task_id2, task_id3 = 1, 4, 9
|
task_id1, task_id2, task_id3 = 1, 4, 9
|
||||||
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
||||||
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
||||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||||
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
mock__get_cancel_kw.assert_called_once_with(FOO)
|
||||||
|
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
|
||||||
|
|
||||||
def test__cancel_group_meta_tasks(self):
|
def test__cancel_group_meta_tasks(self):
|
||||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||||
@ -336,6 +355,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
||||||
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
||||||
|
kw = {BAR: 10, BAZ: 20}
|
||||||
task_id = 555
|
task_id = 555
|
||||||
mock_cancel = MagicMock()
|
mock_cancel = MagicMock()
|
||||||
|
|
||||||
@ -347,27 +367,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
class MockRegister(set, MagicMock):
|
class MockRegister(set, MagicMock):
|
||||||
pass
|
pass
|
||||||
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
|
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
|
||||||
mock_cancel.assert_called_once_with(msg=FOO)
|
mock_cancel.assert_called_once_with(**kw)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
||||||
with self.assertRaises(exceptions.InvalidGroupName):
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
self.task_pool.cancel_group(BAR)
|
self.task_pool.cancel_group(BAR)
|
||||||
mock__cancel_and_remove_all_from_group.assert_not_called()
|
mock__cancel_and_remove_all_from_group.assert_not_called()
|
||||||
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||||
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
|
mock__get_cancel_kw.assert_called_once_with(BAR)
|
||||||
|
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
mock_group_reg = MagicMock()
|
mock_group_reg = MagicMock()
|
||||||
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
||||||
self.assertIsNone(self.task_pool.cancel_all('msg'))
|
self.assertIsNone(self.task_pool.cancel_all(BAZ))
|
||||||
|
mock__get_cancel_kw.assert_called_once_with(BAZ)
|
||||||
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
||||||
call(BAR, mock_group_reg, msg='msg'),
|
call(BAR, mock_group_reg, **fake_cancel_kw),
|
||||||
call(FOO, mock_group_reg, msg='msg')
|
call(FOO, mock_group_reg, **fake_cancel_kw)
|
||||||
])
|
])
|
||||||
|
|
||||||
def test__pop_ended_meta_tasks(self):
|
def test__pop_ended_meta_tasks(self):
|
||||||
@ -764,18 +790,30 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_task')
|
@patch.object(pool.SimpleTaskPool, '_start_task')
|
||||||
async def test__start_num(self, mock__start_task: AsyncMock):
|
async def test__start_num(self, mock__start_task: AsyncMock):
|
||||||
fake_coroutine = object()
|
|
||||||
self.task_pool._func = MagicMock(return_value=fake_coroutine)
|
|
||||||
num = 3
|
|
||||||
group_name = FOO + BAR + 'abc'
|
group_name = FOO + BAR + 'abc'
|
||||||
|
mock_awaitable1, mock_awaitable2 = object(), object()
|
||||||
|
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
|
||||||
|
num = 3
|
||||||
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
self.task_pool._func.assert_has_calls(num * [
|
self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
||||||
call(*self.task_pool._args, **self.task_pool._kwargs)
|
call_kw = {
|
||||||
])
|
'group_name': group_name,
|
||||||
mock__start_task.assert_has_awaits(num * [
|
'end_callback': self.task_pool._end_callback,
|
||||||
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback,
|
'cancel_callback': self.task_pool._cancel_callback
|
||||||
cancel_callback=self.task_pool._cancel_callback)
|
}
|
||||||
])
|
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
|
||||||
|
|
||||||
|
self.task_pool._func.reset_mock(side_effect=True)
|
||||||
|
mock__start_task.reset_mock()
|
||||||
|
|
||||||
|
# Simulate cancellation while the second task is being started.
|
||||||
|
mock__start_task.side_effect = [None, CancelledError, None]
|
||||||
|
mock_coroutine_to_close = MagicMock()
|
||||||
|
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
|
||||||
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
|
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
||||||
|
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
|
||||||
|
mock_coroutine_to_close.close.assert_called_once_with()
|
||||||
|
|
||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
||||||
|
Loading…
Reference in New Issue
Block a user