Compare commits

..

13 Commits

17 changed files with 239 additions and 84 deletions

View File

@ -10,4 +10,3 @@ skip_covered = False
exclude_lines =
if TYPE_CHECKING:
if __name__ == ['"]__main__['"]:
if sys.version_info.+:

View File

@ -1,5 +1,7 @@
name: CI
on: [push]
on:
push:
branches: [master]
jobs:
tests:
name: Python ${{ matrix.python-version }} Tests

View File

@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
author = 'Daniil Fajnberg'
# The full version, including alpha/beta/rc tags
release = '1.0.0-beta'
release = '1.1.4'
# -- General configuration ---------------------------------------------------

View File

@ -45,6 +45,7 @@ Contents
:maxdepth: 2
pages/pool
pages/ids
pages/control
api/api

42
docs/source/pages/ids.rst Normal file
View File

@ -0,0 +1,42 @@
.. This file is part of asyncio-taskpool.
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.
.. Copyright © 2022 Daniil Fajnberg
IDs, groups & names
===================
Task IDs
--------
Every task spawned within a pool receives an ID, which is an integer greater or equal to 0 that is unique **within that task pool instance**. An internal counter is incremented whenever a new task is spawned. A task with ID :code:`n` was the :code:`(n+1)`-th task to be spawned in the pool. Task IDs can be used to cancel specific tasks using the :py:meth:`.cancel() <asyncio_taskpool.pool.BaseTaskPool.cancel>` method.
In practice, it should rarely be necessary to target *specific* tasks. When dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, you would typically cancel entire task groups (see below) rather than individual tasks, whereas with :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instances you would indiscriminately cancel a number of tasks using the :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` method.
The ID of a pool task also appears in the task's name, which is set upon spawning it. (See `here <https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.set_name>`_ for the associated method of the :code:`Task` class.)
Task groups
-----------
Every method of spawning new tasks in a task pool will add them to a **task group** and return the name of that group. With :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` methods such as :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>`, the group name can be set explicitly via the :code:`group_name` parameter. By default, the name will be a string containing some meta information depending on which method is used. Passing an existing task group name in any of those methods will result in a :py:class:`InvalidGroupName <asyncio_taskpool.exceptions.InvalidGroupName>` error.
You can cancel entire task groups using the :py:meth:`.cancel_group() <asyncio_taskpool.pool.BaseTaskPool.cancel_group>` method by passing it the group name. To check which tasks belong to a group, the :py:meth:`.get_group_ids() <asyncio_taskpool.pool.BaseTaskPool.get_group_ids>` method can be used, which takes group names and returns the IDs of the tasks belonging to them.
The :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method will create a new group as well, each time it is called, but it does not allow customizing the group name. Typically, it will not be necessary to keep track of groups in a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instance.
Task groups do not impose limits on the number of tasks in them, although they can be indirectly constrained by pool size limits.
Pool names
----------
When initializing a task pool, you can provide a custom name for it, which will appear in its string representation, e.g. when using it in a :code:`print()`. A class attribute keeps track of initialized task pools and assigns each one an index (similar to IDs for pool tasks). If no name is specified when creating a new pool, its index is used in the string representation of it. Pool names can be helpful when using multiple pools and analyzing log messages.

View File

@ -87,7 +87,7 @@ By contrast, here is how you would do it with a task pool:
...
await pool.flush()
Pretty much self-explanatory, no?
Pretty much self-explanatory, no? (See :doc:`here <./ids>` for more information about groups/names).
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:

View File

@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 1.0.1
version = 1.1.4
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@ -35,7 +35,7 @@ __all__ = []
CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path'
SOCKET_PATH = 'socket_path'
HOST, PORT = 'host', 'port'

View File

@ -85,6 +85,7 @@ class ControlClient(ABC):
"""
self._connected = True
writer.write(json.dumps(self._client_info()).encode())
writer.write(b'\n')
await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
print("Type '-h' to get help and usage instructions for all available commands.\n")
@ -131,6 +132,7 @@ class ControlClient(ABC):
try:
# Send the command to the server.
writer.write(cmd.encode())
writer.write(b'\n')
await writer.drain()
except ConnectionError as e:
self._connected = False

View File

@ -32,7 +32,7 @@ from typing import Callable, Optional, Union, TYPE_CHECKING
from .parser import ControlParser
from ..exceptions import CommandError, HelpRequested, ParserError
from ..pool import TaskPool, SimpleTaskPool
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK
from ..internals.helpers import return_or_exception
if TYPE_CHECKING:
@ -103,7 +103,7 @@ class ControlSession:
elif param.kind == param.VAR_POSITIONAL:
var_pos = kwargs.pop(param.name)
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
self._writer.write(CMD_OK if output is None else str(output).encode())
self._response_buffer.write(CMD_OK.decode() if output is None else str(output))
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
"""
@ -122,10 +122,10 @@ class ControlSession:
if kwargs:
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
await return_or_exception(prop.fset, self._pool, **kwargs)
self._writer.write(CMD_OK)
self._response_buffer.write(CMD_OK.decode())
else:
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
self._response_buffer.write(str(await return_or_exception(prop.fget, self._pool)))
async def client_handshake(self) -> None:
"""
@ -134,7 +134,8 @@ class ControlSession:
Client info is retrieved, server info is sent back, and the
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
"""
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
msg = (await self._reader.readline()).decode().strip()
client_info = json.loads(msg)
log.debug("%s connected", self._client_class_name)
parser_kwargs = {
'stream': self._response_buffer,
@ -146,7 +147,7 @@ class ControlSession:
self._parser.add_subparsers(title="Commands",
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
self._parser.add_class_commands(self._pool.__class__)
self._writer.write(str(self._pool).encode())
self._writer.write(str(self._pool).encode() + b'\n')
await self._writer.drain()
async def _parse_command(self, msg: str) -> None:
@ -187,12 +188,12 @@ class ControlSession:
It will obviously block indefinitely.
"""
while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
msg = (await self._reader.readline()).decode().strip()
if not msg:
log.debug("%s disconnected", self._client_class_name)
break
await self._parse_command(msg)
response = self._response_buffer.getvalue()
response = self._response_buffer.getvalue() + "\n"
self._response_buffer.seek(0)
self._response_buffer.truncate()
self._writer.write(response.encode())

View File

@ -21,8 +21,13 @@ This module should **not** be considered part of the public API.
"""
import sys
PACKAGE_NAME = 'asyncio_taskpool'
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100

View File

@ -20,12 +20,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
import builtins
import sys
from asyncio.coroutines import iscoroutinefunction
from importlib import import_module
from inspect import getdoc
from typing import Any, Callable, Optional, Type, Union
from .constants import PYTHON_BEFORE_39
from .types import T, AnyCallableT, ArgsT, KwArgsT
@ -151,7 +151,7 @@ class ClassMethodWorkaround:
# Starting with Python 3.9, this is thankfully no longer necessary.
if sys.version_info[:2] < (3, 9):
if PYTHON_BEFORE_39:
classmethod = ClassMethodWorkaround
else:
classmethod = builtins.classmethod

View File

@ -28,16 +28,17 @@ For further details about the classes check their respective documentation.
import logging
import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore
from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task, gather
from contextlib import suppress
from math import inf
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
from . import exceptions
from .internals.constants import DEFAULT_TASK_GROUP
from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -71,7 +72,7 @@ class BaseTaskPool:
# Initialize flags; immutably set the name.
self._locked: bool = False
self._closed: bool = False
self._closed: Event = Event()
self._name: str = name
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
@ -220,7 +221,7 @@ class BaseTaskPool:
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if function and not iscoroutinefunction(function):
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
if self._closed:
if self._closed.is_set():
raise exceptions.PoolIsClosed("You must use another pool")
if self._locked and not ignore_lock:
raise exceptions.PoolIsLocked("Cannot start new tasks")
@ -360,6 +361,23 @@ class BaseTaskPool:
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
@staticmethod
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
"""
Returns a dictionary to unpack in a `Task.cancel()` method.
This method exists to ensure proper compatibility with older Python versions.
If `msg` is `None`, an empty dictionary is returned.
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
Otherwise the keyword dictionary contains the `msg` parameter.
"""
if msg is None:
return {}
if PYTHON_BEFORE_39:
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
return {}
return {'msg': msg}
def cancel(self, *task_ids: int, msg: str = None) -> None:
"""
Cancels the tasks with the specified IDs.
@ -378,8 +396,9 @@ class BaseTaskPool:
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
"""
tasks = [self._get_running_task(task_id) for task_id in task_ids]
kw = self._get_cancel_kw(msg)
for task in tasks:
task.cancel(msg=msg)
task.cancel(**kw)
def _cancel_group_meta_tasks(self, group_name: str) -> None:
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
@ -392,7 +411,7 @@ class BaseTaskPool:
self._meta_tasks_cancelled.update(meta_tasks)
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
"""
Removes all tasks from the specified group and cancels them.
@ -406,7 +425,7 @@ class BaseTaskPool:
self._cancel_group_meta_tasks(group_name)
while group_reg:
try:
self._tasks_running[group_reg.pop()].cancel(msg=msg)
self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
except KeyError:
continue
log.debug("%s cancelled tasks from group %s", str(self), group_name)
@ -433,7 +452,8 @@ class BaseTaskPool:
group_reg = self._task_groups.pop(group_name)
except KeyError:
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
kw = self._get_cancel_kw(msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
log.debug("%s forgot task group %s", str(self), group_name)
def cancel_all(self, msg: str = None) -> None:
@ -448,9 +468,10 @@ class BaseTaskPool:
msg (optional): Passed to the `Task.cancel()` method of every task.
"""
log.warning("%s cancelling all tasks!", str(self))
kw = self._get_cancel_kw(msg)
while self._task_groups:
group_name, group_reg = self._task_groups.popitem()
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
def _pop_ended_meta_tasks(self) -> Set[Task]:
"""
@ -529,9 +550,16 @@ class BaseTaskPool:
self._tasks_ended.clear()
self._tasks_cancelled.clear()
self._tasks_running.clear()
self._closed = True
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server.
self._closed.set()
async def until_closed(self) -> bool:
"""
Waits until the pool has been closed. (This method itself does **not** close the pool, but blocks until then.)
Returns:
`True` once the pool is closed.
"""
return await self._closed.wait()
class TaskPool(BaseTaskPool):
@ -604,7 +632,7 @@ class TaskPool(BaseTaskPool):
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
continue
continue # TODO: Consider returning instead of continuing
try:
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback)
@ -737,9 +765,10 @@ class TaskPool(BaseTaskPool):
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
"""
Creates tasks in the pool with arguments from the supplied iterable.
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
The method is a task-based equivalent of the `multiprocessing.pool.Pool.map` method.
All the new tasks are added to the same task group.
@ -791,10 +820,10 @@ class TaskPool(BaseTaskPool):
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`. The method is a task-based
equivalent of the `multiprocessing.pool.Pool.map` method.
All the new tasks are added to the same task group.
@ -848,6 +877,8 @@ class TaskPool(BaseTaskPool):
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
as positional arguments to the function.
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
@ -865,6 +896,8 @@ class TaskPool(BaseTaskPool):
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
"""
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
unpacked as keyword-arguments to the function.
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
@ -941,13 +974,24 @@ class SimpleTaskPool(BaseTaskPool):
async def _start_num(self, num: int, group_name: str) -> None:
"""Starts `num` new tasks in group `group_name`."""
start_coroutines = (
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num)
)
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
await gather(*start_coroutines)
for i in range(num):
try:
coroutine = self._func(*self._args, **self._kwargs)
except Exception as e:
# This means there was probably something wrong with the function arguments.
log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
str(e.__class__.__name__), str(self), self._func.__name__,
repr(self._args), repr(self._kwargs))
continue # TODO: Consider returning instead of continuing
try:
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
cancel_callback=self._cancel_callback)
except CancelledError:
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
# more tasks and can return immediately.
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
coroutine.close()
return
def start(self, num: int) -> str:
"""

View File

@ -71,7 +71,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
self.assertTrue(self.client._connected)
mock__client_info.assert_called_once_with()
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
self.mock_write.assert_has_calls([call(json.dumps(mock_info).encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_has_calls([
@ -121,7 +121,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
self.mock_drain.side_effect = err = ConnectionError()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode())
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_not_awaited()
self.mock_print.assert_called_once_with(err, file=sys.stderr)
@ -133,7 +133,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.mock_print.reset_mock()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode())
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with(FOO)

View File

@ -26,7 +26,7 @@ from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD
from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool
@ -74,7 +74,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
)
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
@patch.object(session, 'return_or_exception')
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
@ -85,15 +85,16 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock()
self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_return_or_exception.return_value = val = 420.69
self.assertIsNone(await self.session._exec_property_and_respond(prop))
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
self.mock_writer.write.assert_called_once_with(str(val).encode())
self.assertEqual(str(val), self.session._response_buffer.getvalue())
@patch.object(session, 'ControlParser')
async def test_client_handshake(self, mock_parser_cls: MagicMock):
@ -102,8 +103,8 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parser_cls.return_value = mock_parser
width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read
mock_readline = AsyncMock(return_value=msg.encode())
self.mock_reader.readline = mock_readline
self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = {
'stream': self.session._response_buffer,
@ -117,11 +118,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
}
self.assertIsNone(await self.session.client_handshake())
self.assertEqual(mock_parser, self.session._parser)
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
mock_readline.assert_awaited_once_with()
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode() + b'\n')
self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session.ControlSession, '_exec_property_and_respond')
@ -190,27 +191,27 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
@patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty():
self.mock_reader.read.return_value = b''
self.mock_reader.readline.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
self.mock_reader.readline = AsyncMock(return_value=f' {msg} '.encode())
response = FOO + BAR + FOO
self.session._response_buffer.write(response)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
self.mock_reader.readline.assert_has_awaits([call(), call()])
mock__parse_command.assert_awaited_once_with(msg)
self.assertEqual('', self.session._response_buffer.getvalue())
self.mock_writer.write.assert_called_once_with(response.encode())
self.mock_writer.write.assert_called_once_with(response.encode() + b'\n')
self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock()
self.mock_reader.readline.reset_mock()
mock__parse_command.reset_mock()
self.mock_writer.write.reset_mock()
self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited()
self.mock_reader.readline.assert_not_awaited()
mock__parse_command.assert_not_awaited()
self.mock_writer.write.assert_not_called()
self.mock_writer.drain.assert_not_awaited()

View File

@ -18,10 +18,11 @@ __doc__ = """
Unittests for the `asyncio_taskpool.helpers` module.
"""
import importlib
from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool.internals import constants
from asyncio_taskpool.internals import helpers
@ -152,3 +153,15 @@ class ClassMethodWorkaroundTestCase(TestCase):
cls = None
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
def test_correct_class(self):
is_older_python = constants.PYTHON_BEFORE_39
try:
constants.PYTHON_BEFORE_39 = True
importlib.reload(helpers)
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
constants.PYTHON_BEFORE_39 = False
importlib.reload(helpers)
self.assertIs(classmethod, helpers.classmethod)
finally:
constants.PYTHON_BEFORE_39 = is_older_python

View File

@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
"""
from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore
from asyncio.locks import Event, Semaphore
from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type
@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(0, self.task_pool._num_started)
self.assertFalse(self.task_pool._locked)
self.assertFalse(self.task_pool._closed)
self.assertIsInstance(self.task_pool._closed, Event)
self.assertFalse(self.task_pool._closed.is_set())
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self):
self.task_pool._closed = True
self.task_pool._closed.set()
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try:
with self.assertRaises(AssertionError):
@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False
self.task_pool._closed.clear()
self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
@ -311,13 +312,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@patch('warnings.warn')
def test__get_cancel_kw(self, mock_warn: MagicMock):
msg = None
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
msg = 'something'
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_called_once()
mock_warn.reset_mock()
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock):
def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
mock__get_cancel_kw.assert_called_once_with(FOO)
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
@ -336,6 +356,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
kw = {BAR: 10, BAZ: 20}
task_id = 555
mock_cancel = MagicMock()
@ -347,27 +368,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
class MockRegister(set, MagicMock):
pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
mock_cancel.assert_called_once_with(msg=FOO)
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
mock_cancel.assert_called_once_with(**kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called()
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
mock__get_cancel_kw.assert_called_once_with(BAR)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_group_reg = MagicMock()
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
self.assertIsNone(self.task_pool.cancel_all('msg'))
self.assertIsNone(self.task_pool.cancel_all(BAZ))
mock__get_cancel_kw.assert_called_once_with(BAZ)
mock__cancel_and_remove_all_from_group.assert_has_calls([
call(BAR, mock_group_reg, msg='msg'),
call(FOO, mock_group_reg, msg='msg')
call(BAR, mock_group_reg, **fake_cancel_kw),
call(FOO, mock_group_reg, **fake_cancel_kw)
])
def test__pop_ended_meta_tasks(self):
@ -435,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertTrue(self.task_pool._closed)
self.assertTrue(self.task_pool._closed.is_set())
async def test_until_closed(self):
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
output = await self.task_pool.until_closed()
self.assertEqual(FOO, output)
self.task_pool._closed.wait.assert_awaited_once_with()
class TaskPoolTestCase(CommonTestCase):
@ -764,18 +797,30 @@ class SimpleTaskPoolTestCase(CommonTestCase):
@patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_num(self, mock__start_task: AsyncMock):
fake_coroutine = object()
self.task_pool._func = MagicMock(return_value=fake_coroutine)
num = 3
group_name = FOO + BAR + 'abc'
mock_awaitable1, mock_awaitable2 = object(), object()
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
num = 3
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(num * [
call(*self.task_pool._args, **self.task_pool._kwargs)
])
mock__start_task.assert_has_awaits(num * [
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback,
cancel_callback=self.task_pool._cancel_callback)
])
self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
call_kw = {
'group_name': group_name,
'end_callback': self.task_pool._end_callback,
'cancel_callback': self.task_pool._cancel_callback
}
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
self.task_pool._func.reset_mock(side_effect=True)
mock__start_task.reset_mock()
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task')
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())