Fix cancel message bug for Python 3.8; test coverage workaround for Python version conditions

This commit is contained in:
Daniil Fajnberg 2022-04-10 10:41:32 +02:00
parent d7cd16c540
commit d047b99119
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
7 changed files with 87 additions and 21 deletions

View File

@ -10,4 +10,3 @@ skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
if sys.version_info.+:

View File

@ -1,5 +1,7 @@
name: CI
on: [push]
on:
push:
branches: [master]
jobs:
tests:
name: Python ${{ matrix.python-version }} Tests

View File

@ -21,8 +21,13 @@ This module should **not** be considered part of the public API.
"""
import sys
PACKAGE_NAME = 'asyncio_taskpool'
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100

View File

@ -20,12 +20,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
import builtins
import sys
from asyncio.coroutines import iscoroutinefunction
from importlib import import_module
from inspect import getdoc
from typing import Any, Callable, Optional, Type, Union
from .constants import PYTHON_BEFORE_39
from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -151,7 +151,7 @@ class ClassMethodWorkaround:
# Starting with Python 3.9, this is thankfully no longer necessary.
if sys.version_info[:2] < (3, 9):
if PYTHON_BEFORE_39:
classmethod = ClassMethodWorkaround
else:
classmethod = builtins.classmethod

View File

@ -28,6 +28,7 @@ For further details about the classes check their respective documentation.
import logging
import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore
@ -37,7 +38,7 @@ from math import inf
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
from . import exceptions
from .internals.constants import DEFAULT_TASK_GROUP
from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -360,6 +361,23 @@ class BaseTaskPool:
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
@staticmethod
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
"""
Returns a dictionary to unpack in a `Task.cancel()` method.
This method exists to ensure proper compatibility with older Python versions.
If `msg` is `None`, an empty dictionary is returned.
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
Otherwise the keyword dictionary contains the `msg` parameter.
"""
if msg is None:
return {}
if PYTHON_BEFORE_39:
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
return {}
return {'msg': msg}
def cancel(self, *task_ids: int, msg: str = None) -> None:
"""
Cancels the tasks with the specified IDs.
@ -378,8 +396,9 @@ class BaseTaskPool:
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
"""
tasks = [self._get_running_task(task_id) for task_id in task_ids]
kw = self._get_cancel_kw(msg)
for task in tasks:
task.cancel(msg=msg)
task.cancel(**kw)
def _cancel_group_meta_tasks(self, group_name: str) -> None:
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
@ -392,7 +411,7 @@ class BaseTaskPool:
self._meta_tasks_cancelled.update(meta_tasks)
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
"""
Removes all tasks from the specified group and cancels them.
@ -406,7 +425,7 @@ class BaseTaskPool:
self._cancel_group_meta_tasks(group_name)
while group_reg:
try:
self._tasks_running[group_reg.pop()].cancel(msg=msg)
self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
except KeyError:
continue
log.debug("%s cancelled tasks from group %s", str(self), group_name)
@ -433,7 +452,8 @@ class BaseTaskPool:
group_reg = self._task_groups.pop(group_name)
except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
kw = self._get_cancel_kw(msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
log.debug("%s forgot task group %s", str(self), group_name)
def cancel_all(self, msg: str = None) -> None:
@ -448,9 +468,10 @@ class BaseTaskPool:
msg (optional): Passed to the `Task.cancel()` method of every task.
"""
log.warning("%s cancelling all tasks!", str(self))
kw = self._get_cancel_kw(msg)
while self._task_groups:
group_name, group_reg = self._task_groups.popitem()
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
def _pop_ended_meta_tasks(self) -> Set[Task]:
"""

View File

@ -18,10 +18,11 @@ __doc__ = """
Unittests for the `asyncio_taskpool.helpers` module.
"""
import importlib
from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import constants
from asyncio_taskpool.internals import helpers
@ -152,3 +153,15 @@ class ClassMethodWorkaroundTestCase(TestCase):
cls = None
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
def test_correct_class(self):
is_older_python = constants.PYTHON_BEFORE_39
try:
constants.PYTHON_BEFORE_39 = True
importlib.reload(helpers)
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
constants.PYTHON_BEFORE_39 = False
importlib.reload(helpers)
self.assertIs(classmethod, helpers.classmethod)
finally:
constants.PYTHON_BEFORE_39 = is_older_python

View File

@ -311,13 +311,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@patch('warnings.warn')
def test__get_cancel_kw(self, mock_warn: MagicMock):
msg = None
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
msg = 'something'
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_called_once()
mock_warn.reset_mock()
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock):
def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
mock__get_cancel_kw.assert_called_once_with(FOO)
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
@ -336,6 +355,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
kw = {BAR: 10, BAZ: 20}
task_id = 555
mock_cancel = MagicMock()
@ -347,27 +367,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
class MockRegister(set, MagicMock):
pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
mock_cancel.assert_called_once_with(msg=FOO)
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
mock_cancel.assert_called_once_with(**kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called()
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
mock__get_cancel_kw.assert_called_once_with(BAR)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_group_reg = MagicMock()
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
self.assertIsNone(self.task_pool.cancel_all('msg'))
self.assertIsNone(self.task_pool.cancel_all(BAZ))
mock__get_cancel_kw.assert_called_once_with(BAZ)
mock__cancel_and_remove_all_from_group.assert_has_calls([
call(BAR, mock_group_reg, msg='msg'),
call(FOO, mock_group_reg, msg='msg')
call(BAR, mock_group_reg, **fake_cancel_kw),
call(FOO, mock_group_reg, **fake_cancel_kw)
])
def test__pop_ended_meta_tasks(self):