generated from daniil-berg/boilerplate-py
Compare commits
No commits in common. "master" and "v1.0.1" have entirely different histories.
@ -10,3 +10,4 @@ skip_covered = False
|
|||||||
exclude_lines =
|
exclude_lines =
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
if __name__ == ['"]__main__['"]:
|
if __name__ == ['"]__main__['"]:
|
||||||
|
if sys.version_info.+:
|
||||||
|
4
.github/workflows/main.yaml
vendored
4
.github/workflows/main.yaml
vendored
@ -1,7 +1,5 @@
|
|||||||
name: CI
|
name: CI
|
||||||
on:
|
on: [push]
|
||||||
push:
|
|
||||||
branches: [master]
|
|
||||||
jobs:
|
jobs:
|
||||||
tests:
|
tests:
|
||||||
name: Python ${{ matrix.python-version }} Tests
|
name: Python ${{ matrix.python-version }} Tests
|
||||||
|
@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
|
|||||||
author = 'Daniil Fajnberg'
|
author = 'Daniil Fajnberg'
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
release = '1.1.4'
|
release = '1.0.0-beta'
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
@ -45,7 +45,6 @@ Contents
|
|||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
pages/pool
|
pages/pool
|
||||||
pages/ids
|
|
||||||
pages/control
|
pages/control
|
||||||
api/api
|
api/api
|
||||||
|
|
||||||
|
@ -1,42 +0,0 @@
|
|||||||
.. This file is part of asyncio-taskpool.
|
|
||||||
|
|
||||||
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
|
||||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
|
||||||
|
|
||||||
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
|
||||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
See the GNU Lesser General Public License for more details.
|
|
||||||
|
|
||||||
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
|
||||||
If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
.. Copyright © 2022 Daniil Fajnberg
|
|
||||||
|
|
||||||
|
|
||||||
IDs, groups & names
|
|
||||||
===================
|
|
||||||
|
|
||||||
Task IDs
|
|
||||||
--------
|
|
||||||
|
|
||||||
Every task spawned within a pool receives an ID, which is an integer greater or equal to 0 that is unique **within that task pool instance**. An internal counter is incremented whenever a new task is spawned. A task with ID :code:`n` was the :code:`(n+1)`-th task to be spawned in the pool. Task IDs can be used to cancel specific tasks using the :py:meth:`.cancel() <asyncio_taskpool.pool.BaseTaskPool.cancel>` method.
|
|
||||||
|
|
||||||
In practice, it should rarely be necessary to target *specific* tasks. When dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, you would typically cancel entire task groups (see below) rather than individual tasks, whereas with :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instances you would indiscriminately cancel a number of tasks using the :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` method.
|
|
||||||
|
|
||||||
The ID of a pool task also appears in the task's name, which is set upon spawning it. (See `here <https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.set_name>`_ for the associated method of the :code:`Task` class.)
|
|
||||||
|
|
||||||
Task groups
|
|
||||||
-----------
|
|
||||||
|
|
||||||
Every method of spawning new tasks in a task pool will add them to a **task group** and return the name of that group. With :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` methods such as :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>`, the group name can be set explicitly via the :code:`group_name` parameter. By default, the name will be a string containing some meta information depending on which method is used. Passing an existing task group name in any of those methods will result in a :py:class:`InvalidGroupName <asyncio_taskpool.exceptions.InvalidGroupName>` error.
|
|
||||||
|
|
||||||
You can cancel entire task groups using the :py:meth:`.cancel_group() <asyncio_taskpool.pool.BaseTaskPool.cancel_group>` method by passing it the group name. To check which tasks belong to a group, the :py:meth:`.get_group_ids() <asyncio_taskpool.pool.BaseTaskPool.get_group_ids>` method can be used, which takes group names and returns the IDs of the tasks belonging to them.
|
|
||||||
|
|
||||||
The :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method will create a new group as well, each time it is called, but it does not allow customizing the group name. Typically, it will not be necessary to keep track of groups in a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instance.
|
|
||||||
|
|
||||||
Task groups do not impose limits on the number of tasks in them, although they can be indirectly constrained by pool size limits.
|
|
||||||
|
|
||||||
Pool names
|
|
||||||
----------
|
|
||||||
|
|
||||||
When initializing a task pool, you can provide a custom name for it, which will appear in its string representation, e.g. when using it in a :code:`print()`. A class attribute keeps track of initialized task pools and assigns each one an index (similar to IDs for pool tasks). If no name is specified when creating a new pool, its index is used in the string representation of it. Pool names can be helpful when using multiple pools and analyzing log messages.
|
|
@ -87,7 +87,7 @@ By contrast, here is how you would do it with a task pool:
|
|||||||
...
|
...
|
||||||
await pool.flush()
|
await pool.flush()
|
||||||
|
|
||||||
Pretty much self-explanatory, no? (See :doc:`here <./ids>` for more information about groups/names).
|
Pretty much self-explanatory, no?
|
||||||
|
|
||||||
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
|
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 1.1.4
|
version = 1.0.1
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -35,7 +35,7 @@ __all__ = []
|
|||||||
|
|
||||||
CLIENT_CLASS = 'client_class'
|
CLIENT_CLASS = 'client_class'
|
||||||
UNIX, TCP = 'unix', 'tcp'
|
UNIX, TCP = 'unix', 'tcp'
|
||||||
SOCKET_PATH = 'socket_path'
|
SOCKET_PATH = 'path'
|
||||||
HOST, PORT = 'host', 'port'
|
HOST, PORT = 'host', 'port'
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,7 +85,6 @@ class ControlClient(ABC):
|
|||||||
"""
|
"""
|
||||||
self._connected = True
|
self._connected = True
|
||||||
writer.write(json.dumps(self._client_info()).encode())
|
writer.write(json.dumps(self._client_info()).encode())
|
||||||
writer.write(b'\n')
|
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
||||||
@ -132,7 +131,6 @@ class ControlClient(ABC):
|
|||||||
try:
|
try:
|
||||||
# Send the command to the server.
|
# Send the command to the server.
|
||||||
writer.write(cmd.encode())
|
writer.write(cmd.encode())
|
||||||
writer.write(b'\n')
|
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
@ -32,7 +32,7 @@ from typing import Callable, Optional, Union, TYPE_CHECKING
|
|||||||
from .parser import ControlParser
|
from .parser import ControlParser
|
||||||
from ..exceptions import CommandError, HelpRequested, ParserError
|
from ..exceptions import CommandError, HelpRequested, ParserError
|
||||||
from ..pool import TaskPool, SimpleTaskPool
|
from ..pool import TaskPool, SimpleTaskPool
|
||||||
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK
|
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES
|
||||||
from ..internals.helpers import return_or_exception
|
from ..internals.helpers import return_or_exception
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -103,7 +103,7 @@ class ControlSession:
|
|||||||
elif param.kind == param.VAR_POSITIONAL:
|
elif param.kind == param.VAR_POSITIONAL:
|
||||||
var_pos = kwargs.pop(param.name)
|
var_pos = kwargs.pop(param.name)
|
||||||
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
|
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
|
||||||
self._response_buffer.write(CMD_OK.decode() if output is None else str(output))
|
self._writer.write(CMD_OK if output is None else str(output).encode())
|
||||||
|
|
||||||
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
|
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
@ -122,10 +122,10 @@ class ControlSession:
|
|||||||
if kwargs:
|
if kwargs:
|
||||||
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
|
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
|
||||||
await return_or_exception(prop.fset, self._pool, **kwargs)
|
await return_or_exception(prop.fset, self._pool, **kwargs)
|
||||||
self._response_buffer.write(CMD_OK.decode())
|
self._writer.write(CMD_OK)
|
||||||
else:
|
else:
|
||||||
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
|
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
|
||||||
self._response_buffer.write(str(await return_or_exception(prop.fget, self._pool)))
|
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
|
||||||
|
|
||||||
async def client_handshake(self) -> None:
|
async def client_handshake(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -134,8 +134,7 @@ class ControlSession:
|
|||||||
Client info is retrieved, server info is sent back, and the
|
Client info is retrieved, server info is sent back, and the
|
||||||
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
|
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
|
||||||
"""
|
"""
|
||||||
msg = (await self._reader.readline()).decode().strip()
|
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
|
||||||
client_info = json.loads(msg)
|
|
||||||
log.debug("%s connected", self._client_class_name)
|
log.debug("%s connected", self._client_class_name)
|
||||||
parser_kwargs = {
|
parser_kwargs = {
|
||||||
'stream': self._response_buffer,
|
'stream': self._response_buffer,
|
||||||
@ -147,7 +146,7 @@ class ControlSession:
|
|||||||
self._parser.add_subparsers(title="Commands",
|
self._parser.add_subparsers(title="Commands",
|
||||||
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
|
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
|
||||||
self._parser.add_class_commands(self._pool.__class__)
|
self._parser.add_class_commands(self._pool.__class__)
|
||||||
self._writer.write(str(self._pool).encode() + b'\n')
|
self._writer.write(str(self._pool).encode())
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
|
||||||
async def _parse_command(self, msg: str) -> None:
|
async def _parse_command(self, msg: str) -> None:
|
||||||
@ -188,12 +187,12 @@ class ControlSession:
|
|||||||
It will obviously block indefinitely.
|
It will obviously block indefinitely.
|
||||||
"""
|
"""
|
||||||
while self._control_server.is_serving():
|
while self._control_server.is_serving():
|
||||||
msg = (await self._reader.readline()).decode().strip()
|
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
|
||||||
if not msg:
|
if not msg:
|
||||||
log.debug("%s disconnected", self._client_class_name)
|
log.debug("%s disconnected", self._client_class_name)
|
||||||
break
|
break
|
||||||
await self._parse_command(msg)
|
await self._parse_command(msg)
|
||||||
response = self._response_buffer.getvalue() + "\n"
|
response = self._response_buffer.getvalue()
|
||||||
self._response_buffer.seek(0)
|
self._response_buffer.seek(0)
|
||||||
self._response_buffer.truncate()
|
self._response_buffer.truncate()
|
||||||
self._writer.write(response.encode())
|
self._writer.write(response.encode())
|
||||||
|
@ -21,13 +21,8 @@ This module should **not** be considered part of the public API.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
|
|
||||||
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
|
|
||||||
|
|
||||||
DEFAULT_TASK_GROUP = 'default'
|
DEFAULT_TASK_GROUP = 'default'
|
||||||
|
|
||||||
SESSION_MSG_BYTES = 1024 * 100
|
SESSION_MSG_BYTES = 1024 * 100
|
||||||
|
@ -20,12 +20,12 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
|||||||
|
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import sys
|
||||||
from asyncio.coroutines import iscoroutinefunction
|
from asyncio.coroutines import iscoroutinefunction
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from inspect import getdoc
|
from inspect import getdoc
|
||||||
from typing import Any, Callable, Optional, Type, Union
|
from typing import Any, Callable, Optional, Type, Union
|
||||||
|
|
||||||
from .constants import PYTHON_BEFORE_39
|
|
||||||
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
||||||
|
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ class ClassMethodWorkaround:
|
|||||||
|
|
||||||
|
|
||||||
# Starting with Python 3.9, this is thankfully no longer necessary.
|
# Starting with Python 3.9, this is thankfully no longer necessary.
|
||||||
if PYTHON_BEFORE_39:
|
if sys.version_info[:2] < (3, 9):
|
||||||
classmethod = ClassMethodWorkaround
|
classmethod = ClassMethodWorkaround
|
||||||
else:
|
else:
|
||||||
classmethod = builtins.classmethod
|
classmethod = builtins.classmethod
|
||||||
|
@ -28,17 +28,16 @@ For further details about the classes check their respective documentation.
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
|
||||||
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Event, Semaphore
|
from asyncio.locks import Semaphore
|
||||||
from asyncio.tasks import Task, create_task, gather
|
from asyncio.tasks import Task, create_task, gather
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from math import inf
|
from math import inf
|
||||||
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
|
from .internals.constants import DEFAULT_TASK_GROUP
|
||||||
from .internals.group_register import TaskGroupRegister
|
from .internals.group_register import TaskGroupRegister
|
||||||
from .internals.helpers import execute_optional, star_function
|
from .internals.helpers import execute_optional, star_function
|
||||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||||
@ -72,7 +71,7 @@ class BaseTaskPool:
|
|||||||
|
|
||||||
# Initialize flags; immutably set the name.
|
# Initialize flags; immutably set the name.
|
||||||
self._locked: bool = False
|
self._locked: bool = False
|
||||||
self._closed: Event = Event()
|
self._closed: bool = False
|
||||||
self._name: str = name
|
self._name: str = name
|
||||||
|
|
||||||
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
||||||
@ -221,7 +220,7 @@ class BaseTaskPool:
|
|||||||
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
||||||
if function and not iscoroutinefunction(function):
|
if function and not iscoroutinefunction(function):
|
||||||
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
||||||
if self._closed.is_set():
|
if self._closed:
|
||||||
raise exceptions.PoolIsClosed("You must use another pool")
|
raise exceptions.PoolIsClosed("You must use another pool")
|
||||||
if self._locked and not ignore_lock:
|
if self._locked and not ignore_lock:
|
||||||
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
||||||
@ -361,23 +360,6 @@ class BaseTaskPool:
|
|||||||
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
||||||
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
|
|
||||||
"""
|
|
||||||
Returns a dictionary to unpack in a `Task.cancel()` method.
|
|
||||||
|
|
||||||
This method exists to ensure proper compatibility with older Python versions.
|
|
||||||
If `msg` is `None`, an empty dictionary is returned.
|
|
||||||
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
|
|
||||||
Otherwise the keyword dictionary contains the `msg` parameter.
|
|
||||||
"""
|
|
||||||
if msg is None:
|
|
||||||
return {}
|
|
||||||
if PYTHON_BEFORE_39:
|
|
||||||
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
|
|
||||||
return {}
|
|
||||||
return {'msg': msg}
|
|
||||||
|
|
||||||
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels the tasks with the specified IDs.
|
Cancels the tasks with the specified IDs.
|
||||||
@ -396,9 +378,8 @@ class BaseTaskPool:
|
|||||||
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
||||||
"""
|
"""
|
||||||
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
||||||
kw = self._get_cancel_kw(msg)
|
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel(**kw)
|
task.cancel(msg=msg)
|
||||||
|
|
||||||
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
||||||
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
||||||
@ -411,7 +392,7 @@ class BaseTaskPool:
|
|||||||
self._meta_tasks_cancelled.update(meta_tasks)
|
self._meta_tasks_cancelled.update(meta_tasks)
|
||||||
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
||||||
|
|
||||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
|
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Removes all tasks from the specified group and cancels them.
|
Removes all tasks from the specified group and cancels them.
|
||||||
|
|
||||||
@ -425,7 +406,7 @@ class BaseTaskPool:
|
|||||||
self._cancel_group_meta_tasks(group_name)
|
self._cancel_group_meta_tasks(group_name)
|
||||||
while group_reg:
|
while group_reg:
|
||||||
try:
|
try:
|
||||||
self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
|
self._tasks_running[group_reg.pop()].cancel(msg=msg)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
continue
|
continue
|
||||||
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
||||||
@ -452,8 +433,7 @@ class BaseTaskPool:
|
|||||||
group_reg = self._task_groups.pop(group_name)
|
group_reg = self._task_groups.pop(group_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
||||||
kw = self._get_cancel_kw(msg)
|
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
|
||||||
log.debug("%s forgot task group %s", str(self), group_name)
|
log.debug("%s forgot task group %s", str(self), group_name)
|
||||||
|
|
||||||
def cancel_all(self, msg: str = None) -> None:
|
def cancel_all(self, msg: str = None) -> None:
|
||||||
@ -468,10 +448,9 @@ class BaseTaskPool:
|
|||||||
msg (optional): Passed to the `Task.cancel()` method of every task.
|
msg (optional): Passed to the `Task.cancel()` method of every task.
|
||||||
"""
|
"""
|
||||||
log.warning("%s cancelling all tasks!", str(self))
|
log.warning("%s cancelling all tasks!", str(self))
|
||||||
kw = self._get_cancel_kw(msg)
|
|
||||||
while self._task_groups:
|
while self._task_groups:
|
||||||
group_name, group_reg = self._task_groups.popitem()
|
group_name, group_reg = self._task_groups.popitem()
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
||||||
|
|
||||||
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
||||||
"""
|
"""
|
||||||
@ -550,16 +529,9 @@ class BaseTaskPool:
|
|||||||
self._tasks_ended.clear()
|
self._tasks_ended.clear()
|
||||||
self._tasks_cancelled.clear()
|
self._tasks_cancelled.clear()
|
||||||
self._tasks_running.clear()
|
self._tasks_running.clear()
|
||||||
self._closed.set()
|
self._closed = True
|
||||||
|
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
|
||||||
async def until_closed(self) -> bool:
|
# await it to allow blocking until a closing command comes from a server.
|
||||||
"""
|
|
||||||
Waits until the pool has been closed. (This method itself does **not** close the pool, but blocks until then.)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`True` once the pool is closed.
|
|
||||||
"""
|
|
||||||
return await self._closed.wait()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskPool(BaseTaskPool):
|
class TaskPool(BaseTaskPool):
|
||||||
@ -632,7 +604,7 @@ class TaskPool(BaseTaskPool):
|
|||||||
# This means there was probably something wrong with the function arguments.
|
# This means there was probably something wrong with the function arguments.
|
||||||
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
||||||
continue # TODO: Consider returning instead of continuing
|
continue
|
||||||
try:
|
try:
|
||||||
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
||||||
cancel_callback=cancel_callback)
|
cancel_callback=cancel_callback)
|
||||||
@ -765,10 +737,9 @@ class TaskPool(BaseTaskPool):
|
|||||||
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||||
"""
|
"""
|
||||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
Creates tasks in the pool with arguments from the supplied iterable.
|
||||||
|
|
||||||
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
||||||
The method is a task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
|
||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
@ -820,10 +791,10 @@ class TaskPool(BaseTaskPool):
|
|||||||
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
|
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||||
|
|
||||||
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`. The method is a task-based
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
equivalent of the `multiprocessing.pool.Pool.map` method.
|
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
|
||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
@ -877,8 +848,6 @@ class TaskPool(BaseTaskPool):
|
|||||||
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
|
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
|
||||||
|
|
||||||
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
||||||
as positional arguments to the function.
|
as positional arguments to the function.
|
||||||
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
||||||
@ -896,8 +865,6 @@ class TaskPool(BaseTaskPool):
|
|||||||
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
|
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
|
||||||
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
|
||||||
|
|
||||||
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
|
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
|
||||||
unpacked as keyword-arguments to the function.
|
unpacked as keyword-arguments to the function.
|
||||||
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
||||||
@ -974,24 +941,13 @@ class SimpleTaskPool(BaseTaskPool):
|
|||||||
|
|
||||||
async def _start_num(self, num: int, group_name: str) -> None:
|
async def _start_num(self, num: int, group_name: str) -> None:
|
||||||
"""Starts `num` new tasks in group `group_name`."""
|
"""Starts `num` new tasks in group `group_name`."""
|
||||||
for i in range(num):
|
start_coroutines = (
|
||||||
try:
|
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
|
||||||
coroutine = self._func(*self._args, **self._kwargs)
|
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
|
||||||
except Exception as e:
|
for _ in range(num)
|
||||||
# This means there was probably something wrong with the function arguments.
|
)
|
||||||
log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
|
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
|
||||||
str(e.__class__.__name__), str(self), self._func.__name__,
|
await gather(*start_coroutines)
|
||||||
repr(self._args), repr(self._kwargs))
|
|
||||||
continue # TODO: Consider returning instead of continuing
|
|
||||||
try:
|
|
||||||
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
|
|
||||||
cancel_callback=self._cancel_callback)
|
|
||||||
except CancelledError:
|
|
||||||
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
|
|
||||||
# more tasks and can return immediately.
|
|
||||||
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
|
|
||||||
coroutine.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
def start(self, num: int) -> str:
|
def start(self, num: int) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -71,7 +71,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||||
self.assertTrue(self.client._connected)
|
self.assertTrue(self.client._connected)
|
||||||
mock__client_info.assert_called_once_with()
|
mock__client_info.assert_called_once_with()
|
||||||
self.mock_write.assert_has_calls([call(json.dumps(mock_info).encode()), call(b'\n')])
|
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
self.mock_print.assert_has_calls([
|
self.mock_print.assert_has_calls([
|
||||||
@ -121,7 +121,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
||||||
self.mock_drain.side_effect = err = ConnectionError()
|
self.mock_drain.side_effect = err = ConnectionError()
|
||||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
|
self.mock_write.assert_called_once_with(cmd.encode())
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_not_awaited()
|
self.mock_read.assert_not_awaited()
|
||||||
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||||
@ -133,7 +133,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.mock_print.reset_mock()
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
|
self.mock_write.assert_called_once_with(cmd.encode())
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
self.mock_print.assert_called_once_with(FOO)
|
self.mock_print.assert_called_once_with(FOO)
|
||||||
|
@ -26,7 +26,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||||
|
|
||||||
from asyncio_taskpool.control import session
|
from asyncio_taskpool.control import session
|
||||||
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD
|
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES
|
||||||
from asyncio_taskpool.exceptions import HelpRequested
|
from asyncio_taskpool.exceptions import HelpRequested
|
||||||
from asyncio_taskpool.pool import SimpleTaskPool
|
from asyncio_taskpool.pool import SimpleTaskPool
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_return_or_exception.assert_awaited_once_with(
|
mock_return_or_exception.assert_awaited_once_with(
|
||||||
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
|
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
|
||||||
)
|
)
|
||||||
self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
|
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
|
||||||
|
|
||||||
@patch.object(session, 'return_or_exception')
|
@patch.object(session, 'return_or_exception')
|
||||||
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
|
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
|
||||||
@ -85,16 +85,15 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_return_or_exception.return_value = None
|
mock_return_or_exception.return_value = None
|
||||||
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
|
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
|
||||||
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
|
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
|
||||||
self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
|
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
|
||||||
|
|
||||||
mock_return_or_exception.reset_mock()
|
mock_return_or_exception.reset_mock()
|
||||||
self.session._response_buffer.seek(0)
|
self.mock_writer.write.reset_mock()
|
||||||
self.session._response_buffer.truncate()
|
|
||||||
|
|
||||||
mock_return_or_exception.return_value = val = 420.69
|
mock_return_or_exception.return_value = val = 420.69
|
||||||
self.assertIsNone(await self.session._exec_property_and_respond(prop))
|
self.assertIsNone(await self.session._exec_property_and_respond(prop))
|
||||||
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
|
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
|
||||||
self.assertEqual(str(val), self.session._response_buffer.getvalue())
|
self.mock_writer.write.assert_called_once_with(str(val).encode())
|
||||||
|
|
||||||
@patch.object(session, 'ControlParser')
|
@patch.object(session, 'ControlParser')
|
||||||
async def test_client_handshake(self, mock_parser_cls: MagicMock):
|
async def test_client_handshake(self, mock_parser_cls: MagicMock):
|
||||||
@ -103,8 +102,8 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parser_cls.return_value = mock_parser
|
mock_parser_cls.return_value = mock_parser
|
||||||
width = 5678
|
width = 5678
|
||||||
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
||||||
mock_readline = AsyncMock(return_value=msg.encode())
|
mock_read = AsyncMock(return_value=msg.encode())
|
||||||
self.mock_reader.readline = mock_readline
|
self.mock_reader.read = mock_read
|
||||||
self.mock_writer.drain = AsyncMock()
|
self.mock_writer.drain = AsyncMock()
|
||||||
expected_parser_kwargs = {
|
expected_parser_kwargs = {
|
||||||
'stream': self.session._response_buffer,
|
'stream': self.session._response_buffer,
|
||||||
@ -118,11 +117,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
}
|
}
|
||||||
self.assertIsNone(await self.session.client_handshake())
|
self.assertIsNone(await self.session.client_handshake())
|
||||||
self.assertEqual(mock_parser, self.session._parser)
|
self.assertEqual(mock_parser, self.session._parser)
|
||||||
mock_readline.assert_awaited_once_with()
|
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
|
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
|
||||||
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
|
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
|
||||||
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode() + b'\n')
|
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
|
||||||
self.mock_writer.drain.assert_awaited_once_with()
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
@patch.object(session.ControlSession, '_exec_property_and_respond')
|
@patch.object(session.ControlSession, '_exec_property_and_respond')
|
||||||
@ -191,27 +190,27 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
@patch.object(session.ControlSession, '_parse_command')
|
@patch.object(session.ControlSession, '_parse_command')
|
||||||
async def test_listen(self, mock__parse_command: AsyncMock):
|
async def test_listen(self, mock__parse_command: AsyncMock):
|
||||||
def make_reader_return_empty():
|
def make_reader_return_empty():
|
||||||
self.mock_reader.readline.return_value = b''
|
self.mock_reader.read.return_value = b''
|
||||||
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
||||||
msg = "fascinating"
|
msg = "fascinating"
|
||||||
self.mock_reader.readline = AsyncMock(return_value=f' {msg} '.encode())
|
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
|
||||||
response = FOO + BAR + FOO
|
response = FOO + BAR + FOO
|
||||||
self.session._response_buffer.write(response)
|
self.session._response_buffer.write(response)
|
||||||
self.assertIsNone(await self.session.listen())
|
self.assertIsNone(await self.session.listen())
|
||||||
self.mock_reader.readline.assert_has_awaits([call(), call()])
|
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
|
||||||
mock__parse_command.assert_awaited_once_with(msg)
|
mock__parse_command.assert_awaited_once_with(msg)
|
||||||
self.assertEqual('', self.session._response_buffer.getvalue())
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
self.mock_writer.write.assert_called_once_with(response.encode() + b'\n')
|
self.mock_writer.write.assert_called_once_with(response.encode())
|
||||||
self.mock_writer.drain.assert_awaited_once_with()
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
self.mock_reader.readline.reset_mock()
|
self.mock_reader.read.reset_mock()
|
||||||
mock__parse_command.reset_mock()
|
mock__parse_command.reset_mock()
|
||||||
self.mock_writer.write.reset_mock()
|
self.mock_writer.write.reset_mock()
|
||||||
self.mock_writer.drain.reset_mock()
|
self.mock_writer.drain.reset_mock()
|
||||||
|
|
||||||
self.mock_server.is_serving = MagicMock(return_value=False)
|
self.mock_server.is_serving = MagicMock(return_value=False)
|
||||||
self.assertIsNone(await self.session.listen())
|
self.assertIsNone(await self.session.listen())
|
||||||
self.mock_reader.readline.assert_not_awaited()
|
self.mock_reader.read.assert_not_awaited()
|
||||||
mock__parse_command.assert_not_awaited()
|
mock__parse_command.assert_not_awaited()
|
||||||
self.mock_writer.write.assert_not_called()
|
self.mock_writer.write.assert_not_called()
|
||||||
self.mock_writer.drain.assert_not_awaited()
|
self.mock_writer.drain.assert_not_awaited()
|
||||||
|
@ -18,11 +18,10 @@ __doc__ = """
|
|||||||
Unittests for the `asyncio_taskpool.helpers` module.
|
Unittests for the `asyncio_taskpool.helpers` module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
|
||||||
from unittest import IsolatedAsyncioTestCase, TestCase
|
from unittest import IsolatedAsyncioTestCase, TestCase
|
||||||
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
||||||
|
|
||||||
from asyncio_taskpool.internals import constants
|
|
||||||
from asyncio_taskpool.internals import helpers
|
from asyncio_taskpool.internals import helpers
|
||||||
|
|
||||||
|
|
||||||
@ -153,15 +152,3 @@ class ClassMethodWorkaroundTestCase(TestCase):
|
|||||||
cls = None
|
cls = None
|
||||||
output = instance.__get__(obj, cls)
|
output = instance.__get__(obj, cls)
|
||||||
self.assertEqual(expected_output, output)
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
def test_correct_class(self):
|
|
||||||
is_older_python = constants.PYTHON_BEFORE_39
|
|
||||||
try:
|
|
||||||
constants.PYTHON_BEFORE_39 = True
|
|
||||||
importlib.reload(helpers)
|
|
||||||
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
|
|
||||||
constants.PYTHON_BEFORE_39 = False
|
|
||||||
importlib.reload(helpers)
|
|
||||||
self.assertIs(classmethod, helpers.classmethod)
|
|
||||||
finally:
|
|
||||||
constants.PYTHON_BEFORE_39 = is_older_python
|
|
||||||
|
@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Event, Semaphore
|
from asyncio.locks import Semaphore
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||||
from typing import Type
|
from typing import Type
|
||||||
@ -83,8 +83,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertEqual(0, self.task_pool._num_started)
|
self.assertEqual(0, self.task_pool._num_started)
|
||||||
|
|
||||||
self.assertFalse(self.task_pool._locked)
|
self.assertFalse(self.task_pool._locked)
|
||||||
self.assertIsInstance(self.task_pool._closed, Event)
|
self.assertFalse(self.task_pool._closed)
|
||||||
self.assertFalse(self.task_pool._closed.is_set())
|
|
||||||
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
||||||
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
@ -163,7 +162,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool.get_group_ids(group_name, 'something else')
|
self.task_pool.get_group_ids(group_name, 'something else')
|
||||||
|
|
||||||
async def test__check_start(self):
|
async def test__check_start(self):
|
||||||
self.task_pool._closed.set()
|
self.task_pool._closed = True
|
||||||
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
||||||
try:
|
try:
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
@ -176,7 +175,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
||||||
with self.assertRaises(exceptions.PoolIsClosed):
|
with self.assertRaises(exceptions.PoolIsClosed):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
||||||
self.task_pool._closed.clear()
|
self.task_pool._closed = False
|
||||||
self.task_pool._locked = True
|
self.task_pool._locked = True
|
||||||
with self.assertRaises(exceptions.PoolIsLocked):
|
with self.assertRaises(exceptions.PoolIsLocked):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
||||||
@ -312,32 +311,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_not_called()
|
mock__task_name.assert_not_called()
|
||||||
|
|
||||||
@patch('warnings.warn')
|
|
||||||
def test__get_cancel_kw(self, mock_warn: MagicMock):
|
|
||||||
msg = None
|
|
||||||
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
|
||||||
mock_warn.assert_not_called()
|
|
||||||
|
|
||||||
msg = 'something'
|
|
||||||
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
|
|
||||||
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
|
||||||
mock_warn.assert_called_once()
|
|
||||||
mock_warn.reset_mock()
|
|
||||||
|
|
||||||
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
|
|
||||||
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
|
|
||||||
mock_warn.assert_not_called()
|
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
|
||||||
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
||||||
def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
|
def test_cancel(self, mock__get_running_task: MagicMock):
|
||||||
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
|
||||||
task_id1, task_id2, task_id3 = 1, 4, 9
|
task_id1, task_id2, task_id3 = 1, 4, 9
|
||||||
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
||||||
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
||||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||||
mock__get_cancel_kw.assert_called_once_with(FOO)
|
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
||||||
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
|
|
||||||
|
|
||||||
def test__cancel_group_meta_tasks(self):
|
def test__cancel_group_meta_tasks(self):
|
||||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||||
@ -356,7 +336,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
||||||
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
||||||
kw = {BAR: 10, BAZ: 20}
|
|
||||||
task_id = 555
|
task_id = 555
|
||||||
mock_cancel = MagicMock()
|
mock_cancel = MagicMock()
|
||||||
|
|
||||||
@ -368,33 +347,27 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
class MockRegister(set, MagicMock):
|
class MockRegister(set, MagicMock):
|
||||||
pass
|
pass
|
||||||
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
|
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
|
||||||
mock_cancel.assert_called_once_with(**kw)
|
mock_cancel.assert_called_once_with(msg=FOO)
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
||||||
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
|
||||||
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
||||||
with self.assertRaises(exceptions.InvalidGroupName):
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
self.task_pool.cancel_group(BAR)
|
self.task_pool.cancel_group(BAR)
|
||||||
mock__cancel_and_remove_all_from_group.assert_not_called()
|
mock__cancel_and_remove_all_from_group.assert_not_called()
|
||||||
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||||
mock__get_cancel_kw.assert_called_once_with(BAR)
|
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
|
||||||
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
|
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
||||||
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
|
||||||
mock_group_reg = MagicMock()
|
mock_group_reg = MagicMock()
|
||||||
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
||||||
self.assertIsNone(self.task_pool.cancel_all(BAZ))
|
self.assertIsNone(self.task_pool.cancel_all('msg'))
|
||||||
mock__get_cancel_kw.assert_called_once_with(BAZ)
|
|
||||||
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
||||||
call(BAR, mock_group_reg, **fake_cancel_kw),
|
call(BAR, mock_group_reg, msg='msg'),
|
||||||
call(FOO, mock_group_reg, **fake_cancel_kw)
|
call(FOO, mock_group_reg, msg='msg')
|
||||||
])
|
])
|
||||||
|
|
||||||
def test__pop_ended_meta_tasks(self):
|
def test__pop_ended_meta_tasks(self):
|
||||||
@ -462,13 +435,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
self.assertTrue(self.task_pool._closed.is_set())
|
self.assertTrue(self.task_pool._closed)
|
||||||
|
|
||||||
async def test_until_closed(self):
|
|
||||||
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
|
|
||||||
output = await self.task_pool.until_closed()
|
|
||||||
self.assertEqual(FOO, output)
|
|
||||||
self.task_pool._closed.wait.assert_awaited_once_with()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskPoolTestCase(CommonTestCase):
|
class TaskPoolTestCase(CommonTestCase):
|
||||||
@ -797,30 +764,18 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_task')
|
@patch.object(pool.SimpleTaskPool, '_start_task')
|
||||||
async def test__start_num(self, mock__start_task: AsyncMock):
|
async def test__start_num(self, mock__start_task: AsyncMock):
|
||||||
group_name = FOO + BAR + 'abc'
|
fake_coroutine = object()
|
||||||
mock_awaitable1, mock_awaitable2 = object(), object()
|
self.task_pool._func = MagicMock(return_value=fake_coroutine)
|
||||||
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
|
|
||||||
num = 3
|
num = 3
|
||||||
|
group_name = FOO + BAR + 'abc'
|
||||||
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
self.task_pool._func.assert_has_calls(num * [
|
||||||
call_kw = {
|
call(*self.task_pool._args, **self.task_pool._kwargs)
|
||||||
'group_name': group_name,
|
])
|
||||||
'end_callback': self.task_pool._end_callback,
|
mock__start_task.assert_has_awaits(num * [
|
||||||
'cancel_callback': self.task_pool._cancel_callback
|
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback,
|
||||||
}
|
cancel_callback=self.task_pool._cancel_callback)
|
||||||
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
|
])
|
||||||
|
|
||||||
self.task_pool._func.reset_mock(side_effect=True)
|
|
||||||
mock__start_task.reset_mock()
|
|
||||||
|
|
||||||
# Simulate cancellation while the second task is being started.
|
|
||||||
mock__start_task.side_effect = [None, CancelledError, None]
|
|
||||||
mock_coroutine_to_close = MagicMock()
|
|
||||||
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
|
|
||||||
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
|
||||||
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
|
||||||
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
|
|
||||||
mock_coroutine_to_close.close.assert_called_once_with()
|
|
||||||
|
|
||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
||||||
|
Loading…
Reference in New Issue
Block a user