Compare commits

...

3 Commits

9 changed files with 130 additions and 41 deletions

View File

@ -10,4 +10,3 @@ skip_covered = False
exclude_lines = exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
if __name__ == ['"]__main__['"]: if __name__ == ['"]__main__['"]:
if sys.version_info.+:

View File

@ -1,5 +1,7 @@
name: CI name: CI
on: [push] on:
push:
branches: [master]
jobs: jobs:
tests: tests:
name: Python ${{ matrix.python-version }} Tests name: Python ${{ matrix.python-version }} Tests

View File

@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
author = 'Daniil Fajnberg' author = 'Daniil Fajnberg'
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = '1.0.0-beta' release = '1.0.2'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 1.0.1 version = 1.0.2
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -21,8 +21,13 @@ This module should **not** be considered part of the public API.
""" """
import sys
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
DEFAULT_TASK_GROUP = 'default' DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100

View File

@ -20,12 +20,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
import builtins import builtins
import sys
from asyncio.coroutines import iscoroutinefunction from asyncio.coroutines import iscoroutinefunction
from importlib import import_module from importlib import import_module
from inspect import getdoc from inspect import getdoc
from typing import Any, Callable, Optional, Type, Union from typing import Any, Callable, Optional, Type, Union
from .constants import PYTHON_BEFORE_39
from .types import T, AnyCallableT, ArgsT, KwArgsT from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -151,7 +151,7 @@ class ClassMethodWorkaround:
# Starting with Python 3.9, this is thankfully no longer necessary. # Starting with Python 3.9, this is thankfully no longer necessary.
if sys.version_info[:2] < (3, 9): if PYTHON_BEFORE_39:
classmethod = ClassMethodWorkaround classmethod = ClassMethodWorkaround
else: else:
classmethod = builtins.classmethod classmethod = builtins.classmethod

View File

@ -28,6 +28,7 @@ For further details about the classes check their respective documentation.
import logging import logging
import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Semaphore
@ -37,7 +38,7 @@ from math import inf
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
from . import exceptions from . import exceptions
from .internals.constants import DEFAULT_TASK_GROUP from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
from .internals.group_register import TaskGroupRegister from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -360,6 +361,23 @@ class BaseTaskPool:
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running") raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
@staticmethod
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
"""
Returns a dictionary to unpack in a `Task.cancel()` method.
This method exists to ensure proper compatibility with older Python versions.
If `msg` is `None`, an empty dictionary is returned.
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
Otherwise the keyword dictionary contains the `msg` parameter.
"""
if msg is None:
return {}
if PYTHON_BEFORE_39:
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
return {}
return {'msg': msg}
def cancel(self, *task_ids: int, msg: str = None) -> None: def cancel(self, *task_ids: int, msg: str = None) -> None:
""" """
Cancels the tasks with the specified IDs. Cancels the tasks with the specified IDs.
@ -378,8 +396,9 @@ class BaseTaskPool:
`InvalidTaskID`: One of the `task_ids` is not known to the pool. `InvalidTaskID`: One of the `task_ids` is not known to the pool.
""" """
tasks = [self._get_running_task(task_id) for task_id in task_ids] tasks = [self._get_running_task(task_id) for task_id in task_ids]
kw = self._get_cancel_kw(msg)
for task in tasks: for task in tasks:
task.cancel(msg=msg) task.cancel(**kw)
def _cancel_group_meta_tasks(self, group_name: str) -> None: def _cancel_group_meta_tasks(self, group_name: str) -> None:
"""Cancels and forgets all meta tasks associated with the task group named `group_name`.""" """Cancels and forgets all meta tasks associated with the task group named `group_name`."""
@ -392,7 +411,7 @@ class BaseTaskPool:
self._meta_tasks_cancelled.update(meta_tasks) self._meta_tasks_cancelled.update(meta_tasks)
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name) log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None: def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
""" """
Removes all tasks from the specified group and cancels them. Removes all tasks from the specified group and cancels them.
@ -406,7 +425,7 @@ class BaseTaskPool:
self._cancel_group_meta_tasks(group_name) self._cancel_group_meta_tasks(group_name)
while group_reg: while group_reg:
try: try:
self._tasks_running[group_reg.pop()].cancel(msg=msg) self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
except KeyError: except KeyError:
continue continue
log.debug("%s cancelled tasks from group %s", str(self), group_name) log.debug("%s cancelled tasks from group %s", str(self), group_name)
@ -433,7 +452,8 @@ class BaseTaskPool:
group_reg = self._task_groups.pop(group_name) group_reg = self._task_groups.pop(group_name)
except KeyError: except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.") raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) kw = self._get_cancel_kw(msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
log.debug("%s forgot task group %s", str(self), group_name) log.debug("%s forgot task group %s", str(self), group_name)
def cancel_all(self, msg: str = None) -> None: def cancel_all(self, msg: str = None) -> None:
@ -448,9 +468,10 @@ class BaseTaskPool:
msg (optional): Passed to the `Task.cancel()` method of every task. msg (optional): Passed to the `Task.cancel()` method of every task.
""" """
log.warning("%s cancelling all tasks!", str(self)) log.warning("%s cancelling all tasks!", str(self))
kw = self._get_cancel_kw(msg)
while self._task_groups: while self._task_groups:
group_name, group_reg = self._task_groups.popitem() group_name, group_reg = self._task_groups.popitem()
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg) self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
def _pop_ended_meta_tasks(self) -> Set[Task]: def _pop_ended_meta_tasks(self) -> Set[Task]:
""" """
@ -604,7 +625,7 @@ class TaskPool(BaseTaskPool):
# This means there was probably something wrong with the function arguments. # This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)", log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs)) str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue continue # TODO: Consider returning instead of continuing
try: try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback, await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback) cancel_callback=cancel_callback)
@ -941,13 +962,24 @@ class SimpleTaskPool(BaseTaskPool):
async def _start_num(self, num: int, group_name: str) -> None: async def _start_num(self, num: int, group_name: str) -> None:
"""Starts `num` new tasks in group `group_name`.""" """Starts `num` new tasks in group `group_name`."""
start_coroutines = ( for i in range(num):
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name, try:
end_callback=self._end_callback, cancel_callback=self._cancel_callback) coroutine = self._func(*self._args, **self._kwargs)
for _ in range(num) except Exception as e:
) # This means there was probably something wrong with the function arguments.
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling! log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
await gather(*start_coroutines) str(e.__class__.__name__), str(self), self._func.__name__,
repr(self._args), repr(self._kwargs))
continue # TODO: Consider returning instead of continuing
try:
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
cancel_callback=self._cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
def start(self, num: int) -> str: def start(self, num: int) -> str:
""" """

View File

@ -18,10 +18,11 @@ __doc__ = """
Unittests for the `asyncio_taskpool.helpers` module. Unittests for the `asyncio_taskpool.helpers` module.
""" """
import importlib
from unittest import IsolatedAsyncioTestCase, TestCase from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import constants
from asyncio_taskpool.internals import helpers from asyncio_taskpool.internals import helpers
@ -152,3 +153,15 @@ class ClassMethodWorkaroundTestCase(TestCase):
cls = None cls = None
output = instance.__get__(obj, cls) output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
def test_correct_class(self):
is_older_python = constants.PYTHON_BEFORE_39
try:
constants.PYTHON_BEFORE_39 = True
importlib.reload(helpers)
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
constants.PYTHON_BEFORE_39 = False
importlib.reload(helpers)
self.assertIs(classmethod, helpers.classmethod)
finally:
constants.PYTHON_BEFORE_39 = is_older_python

View File

@ -311,13 +311,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called() mock__task_name.assert_not_called()
@patch('warnings.warn')
def test__get_cancel_kw(self, mock_warn: MagicMock):
msg = None
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
msg = 'something'
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_called_once()
mock_warn.reset_mock()
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_get_running_task') @patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock): def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
task_id1, task_id2, task_id3 = 1, 4, 9 task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock() mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO)) self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) mock__get_cancel_kw.assert_called_once_with(FOO)
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
def test__cancel_group_meta_tasks(self): def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock() mock_task1, mock_task2 = MagicMock(), MagicMock()
@ -336,6 +355,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks') @patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock): def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
kw = {BAR: 10, BAZ: 20}
task_id = 555 task_id = 555
mock_cancel = MagicMock() mock_cancel = MagicMock()
@ -347,27 +367,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
class MockRegister(set, MagicMock): class MockRegister(set, MagicMock):
pass pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO)) self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
mock_cancel.assert_called_once_with(msg=FOO) mock_cancel.assert_called_once_with(**kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock() self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
with self.assertRaises(exceptions.InvalidGroupName): with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.cancel_group(BAR) self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called() mock__cancel_and_remove_all_from_group.assert_not_called()
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR)) self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR) mock__get_cancel_kw.assert_called_once_with(BAR)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_group_reg = MagicMock() mock_group_reg = MagicMock()
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg} self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
self.assertIsNone(self.task_pool.cancel_all('msg')) self.assertIsNone(self.task_pool.cancel_all(BAZ))
mock__get_cancel_kw.assert_called_once_with(BAZ)
mock__cancel_and_remove_all_from_group.assert_has_calls([ mock__cancel_and_remove_all_from_group.assert_has_calls([
call(BAR, mock_group_reg, msg='msg'), call(BAR, mock_group_reg, **fake_cancel_kw),
call(FOO, mock_group_reg, msg='msg') call(FOO, mock_group_reg, **fake_cancel_kw)
]) ])
def test__pop_ended_meta_tasks(self): def test__pop_ended_meta_tasks(self):
@ -764,18 +790,30 @@ class SimpleTaskPoolTestCase(CommonTestCase):
@patch.object(pool.SimpleTaskPool, '_start_task') @patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_num(self, mock__start_task: AsyncMock): async def test__start_num(self, mock__start_task: AsyncMock):
fake_coroutine = object()
self.task_pool._func = MagicMock(return_value=fake_coroutine)
num = 3
group_name = FOO + BAR + 'abc' group_name = FOO + BAR + 'abc'
mock_awaitable1, mock_awaitable2 = object(), object()
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
num = 3
self.assertIsNone(await self.task_pool._start_num(num, group_name)) self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(num * [ self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
call(*self.task_pool._args, **self.task_pool._kwargs) call_kw = {
]) 'group_name': group_name,
mock__start_task.assert_has_awaits(num * [ 'end_callback': self.task_pool._end_callback,
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback, 'cancel_callback': self.task_pool._cancel_callback
cancel_callback=self.task_pool._cancel_callback) }
]) mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
self.task_pool._func.reset_mock(side_effect=True)
mock__start_task.reset_mock()
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock()) @patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())