Fix cancel message bug for Python 3.8; test coverage workaround for Python version conditions

This commit is contained in:
Daniil Fajnberg 2022-04-10 10:41:32 +02:00
parent d7cd16c540
commit d047b99119
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
7 changed files with 87 additions and 21 deletions

View File

@ -10,4 +10,3 @@ skip_covered = False
exclude_lines = exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
if __name__ == ['"]__main__['"]: if __name__ == ['"]__main__['"]:
if sys.version_info.+:

View File

@ -1,5 +1,7 @@
name: CI name: CI
on: [push] on:
push:
branches: [master]
jobs: jobs:
tests: tests:
name: Python ${{ matrix.python-version }} Tests name: Python ${{ matrix.python-version }} Tests

View File

@ -21,8 +21,13 @@ This module should **not** be considered part of the public API.
""" """
import sys
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
DEFAULT_TASK_GROUP = 'default' DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100

View File

@ -20,12 +20,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
import builtins import builtins
import sys
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from importlib import import_module from importlib import import_module
from inspect import getdoc from inspect import getdoc
from typing import Any, Callable, Optional, Type, Union from typing import Any, Callable, Optional, Type, Union
from .constants import PYTHON_BEFORE_39
from .types import T, AnyCallableT, ArgsT, KwArgsT from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -151,7 +151,7 @@ class ClassMethodWorkaround:
# Starting with Python 3.9, this is thankfully no longer necessary. # Starting with Python 3.9, this is thankfully no longer necessary.
if sys.version_info[:2] < (3, 9): if PYTHON_BEFORE_39:
classmethod = ClassMethodWorkaround classmethod = ClassMethodWorkaround
else: else:
classmethod = builtins.classmethod classmethod = builtins.classmethod

View File

@ -28,6 +28,7 @@ For further details about the classes check their respective documentation.
import logging import logging
import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Semaphore
@ -37,7 +38,7 @@ from math import inf
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
from . import exceptions from . import exceptions
from .internals.constants import DEFAULT_TASK_GROUP from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
from .internals.group_register import TaskGroupRegister from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -360,6 +361,23 @@ class BaseTaskPool:
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running") raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
@staticmethod
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
"""
Returns a dictionary to unpack in a `Task.cancel()` method.
This method exists to ensure proper compatibility with older Python versions.
If `msg` is `None`, an empty dictionary is returned.
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
Otherwise the keyword dictionary contains the `msg` parameter.
"""
if msg is None:
return {}
if PYTHON_BEFORE_39:
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
return {}
return {'msg': msg}
def cancel(self, *task_ids: int, msg: str = None) -> None: def cancel(self, *task_ids: int, msg: str = None) -> None:
""" """
Cancels the tasks with the specified IDs. Cancels the tasks with the specified IDs.
@ -378,8 +396,9 @@ class BaseTaskPool:
`InvalidTaskID`: One of the `task_ids` is not known to the pool. `InvalidTaskID`: One of the `task_ids` is not known to the pool.
""" """
tasks = [self._get_running_task(task_id) for task_id in task_ids] tasks = [self._get_running_task(task_id) for task_id in task_ids]
kw = self._get_cancel_kw(msg)
for task in tasks: for task in tasks:
task.cancel(msg=msg) task.cancel(**kw)
def _cancel_group_meta_tasks(self, group_name: str) -> None: def _cancel_group_meta_tasks(self, group_name: str) -> None:
"""Cancels and forgets all meta tasks associated with the task group named `group_name`.""" """Cancels and forgets all meta tasks associated with the task group named `group_name`."""
@ -392,7 +411,7 @@ class BaseTaskPool:
self._meta_tasks_cancelled.update(meta_tasks) self._meta_tasks_cancelled.update(meta_tasks)
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name) log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None: def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
""" """
Removes all tasks from the specified group and cancels them. Removes all tasks from the specified group and cancels them.
@ -406,7 +425,7 @@ class BaseTaskPool:
self._cancel_group_meta_tasks(group_name) self._cancel_group_meta_tasks(group_name)
while group_reg: while group_reg:
try: try:
self._tasks_running[group_reg.pop()].cancel(msg=msg) self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
except KeyError: except KeyError:
continue continue
log.debug("%s cancelled tasks from group %s", str(self), group_name) log.debug("%s cancelled tasks from group %s", str(self), group_name)
@ -433,7 +452,8 @@ class BaseTaskPool:
group_reg = self._task_groups.pop(group_name) group_reg = self._task_groups.pop(group_name)
except KeyError: except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) kw = self._get_cancel_kw(msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
log.debug("%s forgot task group %s", str(self), group_name) log.debug("%s forgot task group %s", str(self), group_name)
def cancel_all(self, msg: str = None) -> None: def cancel_all(self, msg: str = None) -> None:
@ -448,9 +468,10 @@ class BaseTaskPool:
msg (optional): Passed to the `Task.cancel()` method of every task. msg (optional): Passed to the `Task.cancel()` method of every task.
""" """
log.warning("%s cancelling all tasks!", str(self)) log.warning("%s cancelling all tasks!", str(self))
kw = self._get_cancel_kw(msg)
while self._task_groups: while self._task_groups:
group_name, group_reg = self._task_groups.popitem() group_name, group_reg = self._task_groups.popitem()
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
def _pop_ended_meta_tasks(self) -> Set[Task]: def _pop_ended_meta_tasks(self) -> Set[Task]:
""" """

View File

@ -18,10 +18,11 @@ __doc__ = """
Unittests for the `asyncio_taskpool.helpers` module. Unittests for the `asyncio_taskpool.helpers` module.
""" """
import importlib
from unittest import IsolatedAsyncioTestCase, TestCase from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import constants
from asyncio_taskpool.internals import helpers from asyncio_taskpool.internals import helpers
@ -152,3 +153,15 @@ class ClassMethodWorkaroundTestCase(TestCase):
cls = None cls = None
output = instance.__get__(obj, cls) output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
def test_correct_class(self):
is_older_python = constants.PYTHON_BEFORE_39
try:
constants.PYTHON_BEFORE_39 = True
importlib.reload(helpers)
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
constants.PYTHON_BEFORE_39 = False
importlib.reload(helpers)
self.assertIs(classmethod, helpers.classmethod)
finally:
constants.PYTHON_BEFORE_39 = is_older_python

View File

@ -311,13 +311,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called() mock__task_name.assert_not_called()
@patch('warnings.warn')
def test__get_cancel_kw(self, mock_warn: MagicMock):
msg = None
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
msg = 'something'
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_called_once()
mock_warn.reset_mock()
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_get_running_task') @patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock): def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
task_id1, task_id2, task_id3 = 1, 4, 9 task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock() mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO)) self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) mock__get_cancel_kw.assert_called_once_with(FOO)
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
def test__cancel_group_meta_tasks(self): def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock() mock_task1, mock_task2 = MagicMock(), MagicMock()
@ -336,6 +355,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks') @patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock): def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
kw = {BAR: 10, BAZ: 20}
task_id = 555 task_id = 555
mock_cancel = MagicMock() mock_cancel = MagicMock()
@ -347,27 +367,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
class MockRegister(set, MagicMock): class MockRegister(set, MagicMock):
pass pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO)) self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
mock_cancel.assert_called_once_with(msg=FOO) mock_cancel.assert_called_once_with(**kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock() self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
with self.assertRaises(exceptions.InvalidGroupName): with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.cancel_group(BAR) self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called() mock__cancel_and_remove_all_from_group.assert_not_called()
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR)) self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR) mock__get_cancel_kw.assert_called_once_with(BAR)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_group_reg = MagicMock() mock_group_reg = MagicMock()
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg} self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
self.assertIsNone(self.task_pool.cancel_all('msg')) self.assertIsNone(self.task_pool.cancel_all(BAZ))
mock__get_cancel_kw.assert_called_once_with(BAZ)
mock__cancel_and_remove_all_from_group.assert_has_calls([ mock__cancel_and_remove_all_from_group.assert_has_calls([
call(BAR, mock_group_reg, msg='msg'), call(BAR, mock_group_reg, **fake_cancel_kw),
call(FOO, mock_group_reg, msg='msg') call(FOO, mock_group_reg, **fake_cancel_kw)
]) ])
def test__pop_ended_meta_tasks(self): def test__pop_ended_meta_tasks(self):