Compare commits

..

9 Commits

9 changed files with 67 additions and 44 deletions

View File

@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
author = 'Daniil Fajnberg' author = 'Daniil Fajnberg'
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = '1.0.2' release = '1.1.4'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 1.0.2 version = 1.1.4
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -35,7 +35,7 @@ __all__ = []
CLIENT_CLASS = 'client_class' CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp' UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path' SOCKET_PATH = 'socket_path'
HOST, PORT = 'host', 'port' HOST, PORT = 'host', 'port'

View File

@ -85,6 +85,7 @@ class ControlClient(ABC):
""" """
self._connected = True self._connected = True
writer.write(json.dumps(self._client_info()).encode()) writer.write(json.dumps(self._client_info()).encode())
writer.write(b'\n')
await writer.drain() await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
print("Type '-h' to get help and usage instructions for all available commands.\n") print("Type '-h' to get help and usage instructions for all available commands.\n")
@ -131,6 +132,7 @@ class ControlClient(ABC):
try: try:
# Send the command to the server. # Send the command to the server.
writer.write(cmd.encode()) writer.write(cmd.encode())
writer.write(b'\n')
await writer.drain() await writer.drain()
except ConnectionError as e: except ConnectionError as e:
self._connected = False self._connected = False

View File

@ -32,7 +32,7 @@ from typing import Callable, Optional, Union, TYPE_CHECKING
from .parser import ControlParser from .parser import ControlParser
from ..exceptions import CommandError, HelpRequested, ParserError from ..exceptions import CommandError, HelpRequested, ParserError
from ..pool import TaskPool, SimpleTaskPool from ..pool import TaskPool, SimpleTaskPool
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES from ..internals.constants import CLIENT_INFO, CMD, CMD_OK
from ..internals.helpers import return_or_exception from ..internals.helpers import return_or_exception
if TYPE_CHECKING: if TYPE_CHECKING:
@ -103,7 +103,7 @@ class ControlSession:
elif param.kind == param.VAR_POSITIONAL: elif param.kind == param.VAR_POSITIONAL:
var_pos = kwargs.pop(param.name) var_pos = kwargs.pop(param.name)
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs) output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
self._writer.write(CMD_OK if output is None else str(output).encode()) self._response_buffer.write(CMD_OK.decode() if output is None else str(output))
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None: async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
""" """
@ -122,10 +122,10 @@ class ControlSession:
if kwargs: if kwargs:
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__) log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
await return_or_exception(prop.fset, self._pool, **kwargs) await return_or_exception(prop.fset, self._pool, **kwargs)
self._writer.write(CMD_OK) self._response_buffer.write(CMD_OK.decode())
else: else:
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__) log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode()) self._response_buffer.write(str(await return_or_exception(prop.fget, self._pool)))
async def client_handshake(self) -> None: async def client_handshake(self) -> None:
""" """
@ -134,7 +134,8 @@ class ControlSession:
Client info is retrieved, server info is sent back, and the Client info is retrieved, server info is sent back, and the
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up. :class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
""" """
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) msg = (await self._reader.readline()).decode().strip()
client_info = json.loads(msg)
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
parser_kwargs = { parser_kwargs = {
'stream': self._response_buffer, 'stream': self._response_buffer,
@ -146,7 +147,7 @@ class ControlSession:
self._parser.add_subparsers(title="Commands", self._parser.add_subparsers(title="Commands",
metavar="(A command followed by '-h' or '--help' will show command-specific help.)") metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
self._parser.add_class_commands(self._pool.__class__) self._parser.add_class_commands(self._pool.__class__)
self._writer.write(str(self._pool).encode()) self._writer.write(str(self._pool).encode() + b'\n')
await self._writer.drain() await self._writer.drain()
async def _parse_command(self, msg: str) -> None: async def _parse_command(self, msg: str) -> None:
@ -187,12 +188,12 @@ class ControlSession:
It will obviously block indefinitely. It will obviously block indefinitely.
""" """
while self._control_server.is_serving(): while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() msg = (await self._reader.readline()).decode().strip()
if not msg: if not msg:
log.debug("%s disconnected", self._client_class_name) log.debug("%s disconnected", self._client_class_name)
break break
await self._parse_command(msg) await self._parse_command(msg)
response = self._response_buffer.getvalue() response = self._response_buffer.getvalue() + "\n"
self._response_buffer.seek(0) self._response_buffer.seek(0)
self._response_buffer.truncate() self._response_buffer.truncate()
self._writer.write(response.encode()) self._writer.write(response.encode())

View File

@ -31,7 +31,7 @@ import logging
import warnings import warnings
from asyncio.coroutines import iscoroutine, iscoroutinefunction from asyncio.coroutines import iscoroutine, iscoroutinefunction
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task, gather from asyncio.tasks import Task, create_task, gather
from contextlib import suppress from contextlib import suppress
from math import inf from math import inf
@ -72,7 +72,7 @@ class BaseTaskPool:
# Initialize flags; immutably set the name. # Initialize flags; immutably set the name.
self._locked: bool = False self._locked: bool = False
self._closed: bool = False self._closed: Event = Event()
self._name: str = name self._name: str = name
# The following three dictionaries are the actual containers of the tasks controlled by the pool. # The following three dictionaries are the actual containers of the tasks controlled by the pool.
@ -221,7 +221,7 @@ class BaseTaskPool:
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}") raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
if function and not iscoroutinefunction(function): if function and not iscoroutinefunction(function):
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}") raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
if self._closed: if self._closed.is_set():
raise exceptions.PoolIsClosed("You must use another pool") raise exceptions.PoolIsClosed("You must use another pool")
if self._locked and not ignore_lock: if self._locked and not ignore_lock:
raise exceptions.PoolIsLocked("Cannot start new tasks") raise exceptions.PoolIsLocked("Cannot start new tasks")
@ -550,9 +550,16 @@ class BaseTaskPool:
self._tasks_ended.clear() self._tasks_ended.clear()
self._tasks_cancelled.clear() self._tasks_cancelled.clear()
self._tasks_running.clear() self._tasks_running.clear()
self._closed = True self._closed.set()
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
# await it to allow blocking until a closing command comes from a server. async def until_closed(self) -> bool:
"""
Waits until the pool has been closed. (This method itself does **not** close the pool, but blocks until then.)
Returns:
`True` once the pool is closed.
"""
return await self._closed.wait()
class TaskPool(BaseTaskPool): class TaskPool(BaseTaskPool):
@ -758,9 +765,10 @@ class TaskPool(BaseTaskPool):
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int, def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
""" """
Creates tasks in the pool with arguments from the supplied iterable. Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`. Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
The method is a task-based equivalent of the `multiprocessing.pool.Pool.map` method.
All the new tasks are added to the same task group. All the new tasks are added to the same task group.
@ -812,10 +820,10 @@ class TaskPool(BaseTaskPool):
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None, def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
""" """
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool. Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`. The method is a task-based
equivalent of the `multiprocessing.pool.Pool.map` method.
All the new tasks are added to the same task group. All the new tasks are added to the same task group.
@ -869,6 +877,8 @@ class TaskPool(BaseTaskPool):
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None, def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
""" """
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
as positional arguments to the function. as positional arguments to the function.
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`. Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
@ -886,6 +896,8 @@ class TaskPool(BaseTaskPool):
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1, def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str: group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
""" """
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
unpacked as keyword-arguments to the function. unpacked as keyword-arguments to the function.
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`. Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.

View File

@ -71,7 +71,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer)) self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
self.assertTrue(self.client._connected) self.assertTrue(self.client._connected)
mock__client_info.assert_called_once_with() mock__client_info.assert_called_once_with()
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode()) self.mock_write.assert_has_calls([call(json.dumps(mock_info).encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with() self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_has_calls([ self.mock_print.assert_has_calls([
@ -121,7 +121,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
mock__get_command.return_value = cmd = FOO + BAR + ' 123' mock__get_command.return_value = cmd = FOO + BAR + ' 123'
self.mock_drain.side_effect = err = ConnectionError() self.mock_drain.side_effect = err = ConnectionError()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode()) self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with() self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_not_awaited() self.mock_read.assert_not_awaited()
self.mock_print.assert_called_once_with(err, file=sys.stderr) self.mock_print.assert_called_once_with(err, file=sys.stderr)
@ -133,7 +133,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.mock_print.reset_mock() self.mock_print.reset_mock()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode()) self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with() self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with(FOO) self.mock_print.assert_called_once_with(FOO)

View File

@ -26,7 +26,7 @@ from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session from asyncio_taskpool.control import session
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD
from asyncio_taskpool.exceptions import HelpRequested from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool from asyncio_taskpool.pool import SimpleTaskPool
@ -74,7 +74,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_return_or_exception.assert_awaited_once_with( mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
) )
self.mock_writer.write.assert_called_once_with(session.CMD_OK) self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
@patch.object(session, 'return_or_exception') @patch.object(session, 'return_or_exception')
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock): async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
@ -85,15 +85,16 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_return_or_exception.return_value = None mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs)) self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs) mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
self.mock_writer.write.assert_called_once_with(session.CMD_OK) self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
mock_return_or_exception.reset_mock() mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock() self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_return_or_exception.return_value = val = 420.69 mock_return_or_exception.return_value = val = 420.69
self.assertIsNone(await self.session._exec_property_and_respond(prop)) self.assertIsNone(await self.session._exec_property_and_respond(prop))
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool) mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
self.mock_writer.write.assert_called_once_with(str(val).encode()) self.assertEqual(str(val), self.session._response_buffer.getvalue())
@patch.object(session, 'ControlParser') @patch.object(session, 'ControlParser')
async def test_client_handshake(self, mock_parser_cls: MagicMock): async def test_client_handshake(self, mock_parser_cls: MagicMock):
@ -102,8 +103,8 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parser_cls.return_value = mock_parser mock_parser_cls.return_value = mock_parser
width = 5678 width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' ' msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode()) mock_readline = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read self.mock_reader.readline = mock_readline
self.mock_writer.drain = AsyncMock() self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = { expected_parser_kwargs = {
'stream': self.session._response_buffer, 'stream': self.session._response_buffer,
@ -117,11 +118,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
} }
self.assertIsNone(await self.session.client_handshake()) self.assertIsNone(await self.session.client_handshake())
self.assertEqual(mock_parser, self.session._parser) self.assertEqual(mock_parser, self.session._parser)
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) mock_readline.assert_awaited_once_with()
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs) mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs) mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__) mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode()) self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode() + b'\n')
self.mock_writer.drain.assert_awaited_once_with() self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session.ControlSession, '_exec_property_and_respond') @patch.object(session.ControlSession, '_exec_property_and_respond')
@ -190,27 +191,27 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
@patch.object(session.ControlSession, '_parse_command') @patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock): async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty(): def make_reader_return_empty():
self.mock_reader.read.return_value = b'' self.mock_reader.readline.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating" msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) self.mock_reader.readline = AsyncMock(return_value=f' {msg} '.encode())
response = FOO + BAR + FOO response = FOO + BAR + FOO
self.session._response_buffer.write(response) self.session._response_buffer.write(response)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) self.mock_reader.readline.assert_has_awaits([call(), call()])
mock__parse_command.assert_awaited_once_with(msg) mock__parse_command.assert_awaited_once_with(msg)
self.assertEqual('', self.session._response_buffer.getvalue()) self.assertEqual('', self.session._response_buffer.getvalue())
self.mock_writer.write.assert_called_once_with(response.encode()) self.mock_writer.write.assert_called_once_with(response.encode() + b'\n')
self.mock_writer.drain.assert_awaited_once_with() self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock() self.mock_reader.readline.reset_mock()
mock__parse_command.reset_mock() mock__parse_command.reset_mock()
self.mock_writer.write.reset_mock() self.mock_writer.write.reset_mock()
self.mock_writer.drain.reset_mock() self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False) self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited() self.mock_reader.readline.assert_not_awaited()
mock__parse_command.assert_not_awaited() mock__parse_command.assert_not_awaited()
self.mock_writer.write.assert_not_called() self.mock_writer.write.assert_not_called()
self.mock_writer.drain.assert_not_awaited() self.mock_writer.drain.assert_not_awaited()

View File

@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
""" """
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type from typing import Type
@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(0, self.task_pool._num_started) self.assertEqual(0, self.task_pool._num_started)
self.assertFalse(self.task_pool._locked) self.assertFalse(self.task_pool._locked)
self.assertFalse(self.task_pool._closed) self.assertIsInstance(self.task_pool._closed, Event)
self.assertFalse(self.task_pool._closed.is_set())
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool.get_group_ids(group_name, 'something else') self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self): async def test__check_start(self):
self.task_pool._closed = True self.task_pool._closed.set()
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try: try:
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._check_start(awaitable=None, function=mock_coroutine) self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed): with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None) self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False self.task_pool._closed.clear()
self.task_pool._locked = True self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked): with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
@ -461,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertTrue(self.task_pool._closed) self.assertTrue(self.task_pool._closed.is_set())
async def test_until_closed(self):
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
output = await self.task_pool.until_closed()
self.assertEqual(FOO, output)
self.task_pool._closed.wait.assert_awaited_once_with()
class TaskPoolTestCase(CommonTestCase): class TaskPoolTestCase(CommonTestCase):