huge update:

introduced meta tasks which are used by `_map()`;
introduced task groups;
ending with `gather_and_close()` now;
pool unittests rewritten accordingly;
two new helper classes
This commit is contained in:
Daniil Fajnberg 2022-02-24 19:16:24 +01:00
parent 4994135062
commit c63f079da4
12 changed files with 1214 additions and 715 deletions

View File

@ -15,11 +15,19 @@ If you need control over a task pool at runtime, you can launch an asynchronous
## Usage ## Usage
Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form: Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form:
```python ```python
from asyncio_taskpool import SimpleTaskPool from asyncio_taskpool import SimpleTaskPool
... ...
async def work(foo, bar): ... async def work(foo, bar): ...
... ...
async def main(): async def main():
pool = SimpleTaskPool(work, args=('xyz', 420)) pool = SimpleTaskPool(work, args=('xyz', 420))
await pool.start(5) await pool.start(5)
@ -27,11 +35,11 @@ async def main():
pool.stop(3) pool.stop(3)
... ...
pool.lock() pool.lock()
await pool.gather() await pool.gather_and_close()
... ...
``` ```
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.) Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md). For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md).

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.3.5 version = 0.4.0
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -54,10 +54,10 @@ def parse_cli() -> Dict[str, Any]:
async def main(): async def main():
kwargs = parse_cli() kwargs = parse_cli()
if kwargs[CONN_TYPE] == UNIX: if kwargs[CONN_TYPE] == UNIX:
client = UnixControlClient(path=kwargs[SOCKET_PATH]) client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
elif kwargs[CONN_TYPE] == TCP: elif kwargs[CONN_TYPE] == TCP:
# TODO: Implement the TCP client class # TODO: Implement the TCP client class
client = UnixControlClient(path=kwargs[SOCKET_PATH]) client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
else: else:
print("Invalid connection type", file=sys.stderr) print("Invalid connection type", file=sys.stderr)
sys.exit(2) sys.exit(2)

View File

@ -21,6 +21,9 @@ Constants used by more than one module in the package.
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
DEFAULT_TASK_GROUP = ''
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
CLIENT_EXIT = 'exit' CLIENT_EXIT = 'exit'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100

View File

@ -23,6 +23,10 @@ class PoolException(Exception):
pass pass
class PoolIsClosed(PoolException):
pass
class PoolIsLocked(PoolException): class PoolIsLocked(PoolException):
pass pass
@ -43,6 +47,10 @@ class InvalidTaskID(PoolException):
pass pass
class InvalidGroupName(PoolException):
pass
class PoolStillUnlocked(PoolException): class PoolStillUnlocked(PoolException):
pass pass

View File

@ -0,0 +1,75 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the definition of the `TaskGroupRegister` class.
"""
from asyncio.locks import Lock
from collections.abc import MutableSet
from typing import Iterator, Set
class TaskGroupRegister(MutableSet):
"""
This class combines the interface of a regular `set` with that of the `asyncio.Lock`.
It serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for
preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of
tasks, as well as before starting a task within the group.
"""
def __init__(self, *task_ids: int) -> None:
self._ids: Set[int] = set(task_ids)
self._lock = Lock()
def __contains__(self, task_id: int) -> bool:
"""Abstract method for the `MutableSet` base class."""
return task_id in self._ids
def __iter__(self) -> Iterator[int]:
"""Abstract method for the `MutableSet` base class."""
return iter(self._ids)
def __len__(self) -> int:
"""Abstract method for the `MutableSet` base class."""
return len(self._ids)
def add(self, task_id: int) -> None:
"""Abstract method for the `MutableSet` base class."""
self._ids.add(task_id)
def discard(self, task_id: int) -> None:
"""Abstract method for the `MutableSet` base class."""
self._ids.discard(task_id)
async def acquire(self) -> bool:
"""Wrapper around the lock's `acquire()` method."""
return await self._lock.acquire()
def release(self) -> None:
"""Wrapper around the lock's `release()` method."""
self._lock.release()
async def __aenter__(self) -> None:
"""Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
await self._lock.acquire()
return None
async def __aexit__(self, exc_type, exc, tb) -> None:
"""Provides the asynchronous context manager syntax `async with ... :` when using the lock."""
self._lock.release()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,58 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
This module contains the definition of an `asyncio.Queue` subclass.
"""
from asyncio.queues import Queue as _Queue
from typing import Any
class Queue(_Queue):
"""This just adds a little syntactic sugar to the `asyncio.Queue`."""
def item_processed(self) -> None:
"""
Does exactly the same as `task_done()`.
This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing,
invites confusion, and immensely reduces readability (in the context of this library). And readability counts.
"""
self.task_done()
async def __aenter__(self) -> Any:
"""
Implements an asynchronous context manager for the queue.
Upon entering `get()` is awaited and subsequently whatever came out of the queue is returned.
It allows writing code this way:
>>> queue = Queue()
>>> ...
>>> async with queue as item:
>>> ...
"""
return await self.get()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""
Implements an asynchronous context manager for the queue.
Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want,
but in some situations it makes the codes much cleaner.
"""
self.item_processed()

View File

@ -32,8 +32,8 @@ KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[[...], Awaitable[Any]]
EndCallbackT = Callable EndCB = Callable
CancelCallbackT = Callable CancelCB = Callable
ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]] ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]]
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

View File

@ -18,19 +18,20 @@ __doc__ = """
Unittests for the `asyncio_taskpool.pool` module. Unittests for the `asyncio_taskpool.pool` module.
""" """
import asyncio
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.queues import Queue from asyncio.locks import Semaphore
from asyncio.queues import QueueEmpty
from datetime import datetime
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type from typing import Type
from asyncio_taskpool import pool, exceptions from asyncio_taskpool import pool, exceptions
from asyncio_taskpool.constants import DATETIME_FORMAT
EMPTY_LIST, EMPTY_DICT = [], {} EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
FOO, BAR = 'foo', 'bar' FOO, BAR, BAZ = 'foo', 'bar', 'baz'
class TestException(Exception): class TestException(Exception):
@ -45,19 +46,12 @@ class CommonTestCase(IsolatedAsyncioTestCase):
task_pool: pool.BaseTaskPool task_pool: pool.BaseTaskPool
log_lvl: int log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = pool.log.level
pool.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
def get_task_pool_init_params(self) -> dict: def get_task_pool_init_params(self) -> dict:
return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME} return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
def setUp(self) -> None: def setUp(self) -> None:
self.log_lvl = pool.log.level
pool.log.setLevel(999)
self._pools = self.TEST_CLASS._pools self._pools = self.TEST_CLASS._pools
# These three methods are called during initialization, so we mock them by default during setup: # These three methods are called during initialization, so we mock them by default during setup:
self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool') self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
@ -76,6 +70,7 @@ class CommonTestCase(IsolatedAsyncioTestCase):
self._add_pool_patcher.stop() self._add_pool_patcher.stop()
self.pool_size_patcher.stop() self.pool_size_patcher.stop()
self.dunder_str_patcher.stop() self.dunder_str_patcher.stop()
pool.log.setLevel(self.log_lvl)
class BaseTaskPoolTestCase(CommonTestCase): class BaseTaskPoolTestCase(CommonTestCase):
@ -88,19 +83,23 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools) self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self): def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) self.assertEqual(0, self.task_pool._num_started)
self.assertEqual(0, self.task_pool._num_cancellations)
self.assertFalse(self.task_pool._locked) self.assertFalse(self.task_pool._locked)
self.assertEqual(0, self.task_pool._counter) self.assertFalse(self.task_pool._closed)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._ended)
self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event) self.assertIsInstance(self.task_pool._enough_room, Semaphore)
self.assertFalse(self.task_pool._interrupt_flag.is_set()) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
self.assertEqual(self.mock_idx, self.task_pool._idx)
self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE) self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with() self.mock___str__.assert_called_once_with()
@ -143,26 +142,56 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertFalse(self.task_pool._locked) self.assertFalse(self.task_pool._locked)
def test_num_running(self): def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz'] self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_running) self.assertEqual(3, self.task_pool.num_running)
def test_num_cancelled(self): def test_num_cancellations(self):
self.task_pool._num_cancelled = 3 self.task_pool._num_cancellations = 3
self.assertEqual(3, self.task_pool.num_cancelled) self.assertEqual(3, self.task_pool.num_cancellations)
def test_num_ended(self): def test_num_ended(self):
self.task_pool._num_ended = 3 self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(3, self.task_pool.num_ended) self.assertEqual(3, self.task_pool.num_ended)
def test_num_finished(self): def test_num_finished(self):
self.task_pool._num_cancelled = cancelled = 69 self.task_pool._num_cancellations = num_cancellations = 69
self.task_pool._num_ended = ended = 420 num_ended = 420
self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'} self.task_pool._tasks_ended = {i: FOO for i in range(num_ended)}
self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished) self.task_pool._tasks_cancelled = mock_cancelled_dict = {1: FOO, 2: BAR, 3: BAZ}
self.assertEqual(num_ended - num_cancellations + len(mock_cancelled_dict), self.task_pool.num_finished)
def test_is_full(self): def test_is_full(self):
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
def test_get_task_group_ids(self):
group_name, ids = 'abcdef', [1, 2, 3]
self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name))
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.get_task_group_ids('something else')
async def test__check_start(self):
self.task_pool._closed = True
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try:
with self.assertRaises(AssertionError):
self.task_pool._check_start(awaitable=None, function=None)
with self.assertRaises(AssertionError):
self.task_pool._check_start(awaitable=mock_coroutine, function=mock_coroutine_function)
with self.assertRaises(exceptions.NotCoroutine):
self.task_pool._check_start(awaitable=mock_coroutine_function, function=None)
with self.assertRaises(exceptions.NotCoroutine):
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False
self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
self.assertIsNone(self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=True))
finally:
await mock_coroutine
def test__task_name(self): def test__task_name(self):
i = 123 i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@ -171,12 +200,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3 self.task_pool._num_cancellations = cancelled = 3
self.task_pool._running[task_id] = mock_task self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback)) self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running) self.assertNotIn(task_id, self.task_pool._tasks_running)
self.assertEqual(mock_task, self.task_pool._cancelled[task_id]) self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled) self.assertEqual(cancelled + 1, self.task_pool._num_cancellations)
mock__task_name.assert_called_with(task_id) mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -184,15 +213,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123 self.task_pool._enough_room._value = room = 123
# End running task: # End running task:
self.task_pool._running[task_id] = mock_task self.task_pool._tasks_running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running) self.assertNotIn(task_id, self.task_pool._tasks_running)
self.assertEqual(mock_task, self.task_pool._ended[task_id]) self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value) self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id) mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -200,11 +227,10 @@ class BaseTaskPoolTestCase(CommonTestCase):
mock_execute_optional.reset_mock() mock_execute_optional.reset_mock()
# End cancelled task: # End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id) self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_ended.pop(task_id)
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._cancelled) self.assertNotIn(task_id, self.task_pool._tasks_cancelled)
self.assertEqual(mock_task, self.task_pool._ended[task_id]) self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id])
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value) self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id) mock__task_name.assert_called_with(task_id)
mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, ))
@ -246,92 +272,52 @@ class BaseTaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock) @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock, @patch.object(pool, 'TaskGroupRegister')
mock_create_task: MagicMock): @patch.object(pool.BaseTaskPool, '_check_start')
def reset_mocks() -> None: async def test__start_task(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__task_name: MagicMock,
mock__task_name.reset_mock() mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
mock__task_wrapper.reset_mock() mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_create_task.reset_mock()
mock_create_task.return_value = mock_task = MagicMock() mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock() mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock() mock_coroutine, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123 self.task_pool._num_started = count = 123
self.task_pool._enough_room._value = room = 123 self.task_pool._enough_room._value = room = 123
group_name, ignore_lock = 'testgroup', True
def check_nothing_changed() -> None: output = await self.task_pool._start_task(mock_coroutine, group_name=group_name, ignore_lock=ignore_lock,
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
with self.assertRaises(exceptions.NotCoroutine):
await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
check_nothing_changed()
self.task_pool._locked = True
ignore_closed = False
mock_awaitable = mock_coroutine()
with self.assertRaises(exceptions.PoolIsLocked):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
check_nothing_changed()
ignore_closed = True
mock_awaitable = mock_coroutine()
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
await mock_awaitable
self.assertEqual(count, output) self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter) mock__check_start.assert_called_once_with(awaitable=mock_coroutine, ignore_lock=ignore_lock)
self.assertEqual(mock_task, self.task_pool._running[count])
self.assertEqual(room - 1, self.task_pool._enough_room._value) self.assertEqual(room - 1, self.task_pool._enough_room._value)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name])
mock_reg_cls.assert_called_once_with()
mock_group_reg.__aenter__.assert_awaited_once_with()
mock_group_reg.add.assert_called_once_with(count)
mock__task_name.assert_called_once_with(count) mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) mock__task_wrapper.assert_called_once_with(mock_coroutine, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) mock_create_task.assert_called_once_with(coro=mock_wrapped, name=FOO)
reset_mocks() self.assertEqual(mock_task, self.task_pool._tasks_running[count])
self.task_pool._counter = count mock_group_reg.__aexit__.assert_awaited_once()
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
mock_awaitable = mock_coroutine()
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
await mock_awaitable
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
def test__get_running_task(self, mock__task_name: MagicMock): def test__get_running_task(self, mock__task_name: MagicMock):
task_id, mock_task = 555, MagicMock() task_id, mock_task = 555, MagicMock()
self.task_pool._running[task_id] = mock_task self.task_pool._tasks_running[task_id] = mock_task
output = self.task_pool._get_running_task(task_id) output = self.task_pool._get_running_task(task_id)
self.assertEqual(mock_task, output) self.assertEqual(mock_task, output)
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id) self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_running.pop(task_id)
with self.assertRaises(exceptions.AlreadyCancelled): with self.assertRaises(exceptions.AlreadyCancelled):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id) mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock() mock__task_name.reset_mock()
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id) self.task_pool._tasks_ended[task_id] = self.task_pool._tasks_cancelled.pop(task_id)
with self.assertRaises(exceptions.TaskEnded): with self.assertRaises(exceptions.TaskEnded):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id) mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock() mock__task_name.reset_mock()
del self.task_pool._ended[task_id] del self.task_pool._tasks_ended[task_id]
with self.assertRaises(exceptions.InvalidTaskID): with self.assertRaises(exceptions.InvalidTaskID):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called() mock__task_name.assert_not_called()
@ -344,263 +330,416 @@ class BaseTaskPoolTestCase(CommonTestCase):
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
def test_cancel_all(self): def test__cancel_and_remove_all_from_group(self):
mock_task1, mock_task2 = MagicMock(), MagicMock() task_id = 555
self.task_pool._running = {1: mock_task1, 2: mock_task2} mock_cancel = MagicMock()
assert not self.task_pool._interrupt_flag.is_set() self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
self.assertIsNone(self.task_pool.cancel_all(FOO))
self.assertTrue(self.task_pool._interrupt_flag.is_set()) class MockRegister(set, MagicMock):
mock_task1.cancel.assert_called_once_with(msg=FOO) pass
mock_task2.cancel.assert_called_once_with(msg=FOO) self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
mock_cancel.assert_called_once_with(msg=FOO)
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
self.task_pool._task_groups[FOO] = mock_group_reg
with self.assertRaises(exceptions.InvalidGroupName):
await self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called()
mock_grp_aenter.assert_not_called()
mock_grp_aexit.assert_not_called()
self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR))
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
mock_grp_aenter.assert_awaited_once_with()
mock_grp_aexit.assert_awaited_once()
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
self.task_pool._task_groups[BAR] = mock_group_reg
self.assertIsNone(await self.task_pool.cancel_all(FOO))
mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO)
mock_grp_aenter.assert_awaited_once_with()
mock_grp_aexit.assert_awaited_once()
async def test_flush(self): async def test_flush(self):
test_exception = TestException() mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) self.task_pool._tasks_ended = {123: mock_ended_func()}
self.task_pool._ended = {123: mock_ended_func()} self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()} self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
self.task_pool._interrupt_flag.set() mock_ended_func.assert_awaited_once_with()
output = await self.task_pool.flush(return_exceptions=True) mock_cancelled_func.assert_awaited_once_with()
self.assertListEqual([FOO, test_exception], output) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.task_pool._ended = {123: mock_ended_func()} async def test_gather_and_close(self):
self.task_pool._cancelled = {456: mock_cancelled_func()} mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
output = await self.task_pool.flush(return_exceptions=True) mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.assertListEqual([FOO, test_exception], output) self.task_pool._before_gathering = before_gather = [mock_before_gather()]
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._tasks_running = running = {789: mock_running_func()}
async def test_gather(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR)
mock_queue_join = AsyncMock()
self.task_pool._before_gathering = before_gather = [mock_queue_join()]
self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set()
assert not self.task_pool._locked
with self.assertRaises(exceptions.PoolStillUnlocked): with self.assertRaises(exceptions.PoolStillUnlocked):
await self.task_pool.gather() await self.task_pool.gather_and_close()
self.assertDictEqual(self.task_pool._ended, ended) self.assertDictEqual(ended, self.task_pool._tasks_ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled) self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
self.assertDictEqual(self.task_pool._running, running) self.assertDictEqual(running, self.task_pool._tasks_running)
self.assertListEqual(self.task_pool._before_gathering, before_gather) self.assertListEqual(before_gather, self.task_pool._before_gathering)
self.assertTrue(self.task_pool._interrupt_flag.is_set()) self.assertFalse(self.task_pool._closed)
self.task_pool._locked = True self.task_pool._locked = True
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
def check_assertions(output) -> None: mock_before_gather.assert_awaited_once_with()
self.assertListEqual([FOO, test_exception, BAR], output) mock_ended_func.assert_awaited_once_with()
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) mock_cancelled_func.assert_awaited_once_with()
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) mock_running_func.assert_awaited_once_with()
self.assertDictEqual(self.task_pool._running, EMPTY_DICT) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertFalse(self.task_pool._interrupt_flag.is_set()) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
check_assertions(await self.task_pool.gather(return_exceptions=True)) self.assertTrue(self.task_pool._closed)
self.task_pool._before_gathering = [mock_queue_join()]
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
check_assertions(await self.task_pool.gather(return_exceptions=True))
class TaskPoolTestCase(CommonTestCase): class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool task_pool: pool.TaskPool
@patch.object(pool.TaskPool, '_start_task') def setUp(self) -> None:
async def test__apply_one(self, mock__start_task: AsyncMock): self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
mock__start_task.return_value = expected_output = 12345 self.base_class_init = self.base_class_init_patcher.start()
mock_awaitable = MagicMock() super().setUp()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} def tearDown(self) -> None:
end_cb, cancel_cb = MagicMock(), MagicMock() self.base_class_init_patcher.stop()
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb) super().tearDown()
def test_init(self):
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
self.task_pool._group_meta_tasks_running[BAR] = {mock_task1, mock_task2}
self.assertIsNone(self.task_pool._cancel_group_meta_tasks(FOO))
self.assertDictEqual({BAR: {mock_task1, mock_task2}}, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
mock_task1.cancel.assert_not_called()
mock_task2.cancel.assert_not_called()
self.assertIsNone(self.task_pool._cancel_group_meta_tasks(BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual({mock_task1, mock_task2}, self.task_pool._meta_tasks_cancelled)
mock_task1.cancel.assert_called_once_with()
mock_task2.cancel.assert_called_once_with()
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
@patch.object(pool.TaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock,
mock_base__cancel_and_remove_all_from_group: MagicMock):
group_name, group_reg, msg = 'xyz', MagicMock(), FOO
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg))
mock__cancel_group_meta_tasks.assert_called_once_with(group_name)
mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg)
@patch.object(pool.BaseTaskPool, 'cancel_group')
async def test_cancel_group(self, mock_base_cancel_group: AsyncMock):
group_name, msg = 'abc', 'xyz'
await self.task_pool.cancel_group(group_name, msg=msg)
mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg)
@patch.object(pool.BaseTaskPool, 'cancel_all')
async def test_cancel_all(self, mock_base_cancel_all: AsyncMock):
msg = 'xyz'
await self.task_pool.cancel_all(msg=msg)
mock_base_cancel_all.assert_awaited_once_with(msg=msg)
def test__pop_ended_meta_tasks(self):
mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
self.task_pool._group_meta_tasks_running[FOO] = {mock_task, mock_done_task1}
mock_done_task2, mock_done_task3 = MagicMock(done=lambda: True), MagicMock(done=lambda: True)
self.task_pool._group_meta_tasks_running[BAR] = {mock_done_task2, mock_done_task3}
expected_output = {mock_done_task1, mock_done_task2, mock_done_task3}
output = self.task_pool._pop_ended_meta_tasks()
self.assertSetEqual(expected_output, output)
self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
@patch.object(pool.TaskPool, '_pop_ended_meta_tasks')
@patch.object(pool.BaseTaskPool, 'flush')
async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock):
mock_ended_meta_task = AsyncMock()
mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
self.assertIsNone(await self.task_pool.flush(return_exceptions=False))
mock_base_flush.assert_awaited_once_with(return_exceptions=False)
mock__pop_ended_meta_tasks.assert_called_once_with()
mock_ended_meta_task.assert_awaited_once_with()
mock_cancelled_meta_task.assert_awaited_once_with()
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
@patch.object(pool.BaseTaskPool, 'gather_and_close')
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock):
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
mock_meta_task1.assert_awaited_once_with()
mock_meta_task2.assert_awaited_once_with()
mock_cancelled_meta_task.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
@patch.object(pool, 'datetime')
def test__generate_group_name(self, mock_datetime: MagicMock):
prefix, func = 'x y z', AsyncMock(__name__=BAR)
dt = datetime(1776, 7, 4, 0, 0, 1)
mock_datetime.now = MagicMock(return_value=dt)
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}'
output = pool.TaskPool._generate_group_name(prefix, func)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) @patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock):
group_name = FOO + BAR
mock_awaitable = object()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
mock__start_task.assert_has_awaits(3 * [
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
])
mock_func.reset_mock() mock_func.reset_mock()
mock__start_task.reset_mock() mock__start_task.reset_mock()
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb) self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
self.assertEqual(expected_output, output) mock_func.assert_has_calls(num * [call(*args)])
mock_func.assert_called_once_with(*args) mock__start_task.assert_has_awaits(num * [
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
])
@patch.object(pool.TaskPool, '_apply_one') @patch.object(pool, 'create_task')
async def test_apply(self, mock__apply_one: AsyncMock): @patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
mock__apply_one.return_value = mock_id = 67890 @patch.object(pool, 'TaskGroupRegister')
mock_func, num = MagicMock(), 3 @patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start')
async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object()
mock_task_future = AsyncMock()
mock_create_task.return_value = mock_task_future()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
expected_output = num * [mock_id] self.task_pool._task_groups = {}
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
async def test__queue_producer(self): def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock_group_reg.__aenter__.assert_awaited_once_with()
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine)
mock_group_reg.__aexit__.assert_awaited_once()
mock_task_future.assert_awaited_once_with()
output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
check_assertions(group_name, output)
mock__generate_group_name.assert_not_called()
mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
mock_group_reg.__aenter__.reset_mock()
mock__apply_num.reset_mock()
mock_create_task.reset_mock()
mock_group_reg.__aexit__.reset_mock()
mock_task_future = AsyncMock()
mock_create_task.return_value = mock_task_future()
output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
check_assertions(generated_name, output)
mock__generate_group_name.assert_called_once_with('apply', mock_func)
@patch.object(pool, 'Queue')
async def test__queue_producer(self, mock_queue_cls: MagicMock):
mock_put = AsyncMock() mock_put = AsyncMock()
mock_q = MagicMock(put=mock_put) mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
args = (FOO, BAR, 123) item1, item2, item3 = FOO, 420, 69
assert not self.task_pool._interrupt_flag.is_set() arg_iter = iter([item1, item2, item3])
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
mock_put.assert_has_awaits([call(arg) for arg in args]) mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
with self.assertRaises(StopIteration):
next(arg_iter)
mock_put.reset_mock() mock_put.reset_mock()
self.task_pool._interrupt_flag.set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_not_awaited()
@patch.object(pool, 'partial') mock_put.side_effect = [CancelledError, None]
@patch.object(pool, 'star_function') arg_iter = iter([item1, item2, item3])
@patch.object(pool.TaskPool, '_start_task') mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock, self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
mock_partial: MagicMock): mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
mock_partial.return_value = queue_callback = 'not really' mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
mock_star_function.return_value = awaitable = 'totally an awaitable' mock_queue.item_processed.assert_has_calls([call(), call()])
q, arg = Queue(), 420.69 self.assertListEqual([item2, item3], list(arg_iter))
q.put_nowait(arg)
mock_func, stars = MagicMock(), 3
mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
mock_partial.assert_not_called()
@patch.object(pool, 'execute_optional') @patch.object(pool, 'execute_optional')
@patch.object(pool.TaskPool, '_queue_consumer') async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock): semaphore, mock_end_cb = Semaphore(1), MagicMock()
task_id, mock_q = 420, MagicMock() wrapped = pool.TaskPool._get_map_end_callback(semaphore, mock_end_cb)
mock_func, stars = MagicMock(), 3 task_id = 1234
mock_wait = AsyncMock() await wrapped(task_id)
mock_flag = MagicMock(wait=mock_wait) self.assertEqual(2, semaphore._value)
end_cb, cancel_cb = MagicMock(), MagicMock() mock_execute_optional.assert_awaited_once_with(mock_end_cb, args=(task_id,))
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb)) @patch.object(pool, 'star_function')
mock_wait.assert_awaited_once_with() @patch.object(pool.TaskPool, '_start_task')
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars, @patch.object(pool, 'Semaphore')
end_callback=end_cb, cancel_callback=cancel_cb) @patch.object(pool.TaskPool, '_get_map_end_callback')
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,)) async def test__queue_consumer(self, mock__get_map_end_callback: MagicMock, mock_semaphore_cls: MagicMock,
mock__start_task: AsyncMock, mock_star_function: MagicMock):
mock__get_map_end_callback.return_value = map_cb = MagicMock()
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
mock_star_function.return_value = awaitable = 'totally an awaitable'
arg1, arg2 = 123456789, 'function argument'
mock_q_maxsize = 3
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, pool.TaskPool._QUEUE_END_SENTINEL]),
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
group_name, mock_func, stars = 'whatever', MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(semaphore.locked())
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
@patch.object(pool, 'iter')
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool, 'join_queue', new_callable=MagicMock) @patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock, @patch.object(pool, 'join_queue', new_callable=MagicMock)
mock_create_task: MagicMock, mock_iter: MagicMock): @patch.object(pool, 'Queue')
args, num_tasks = (FOO, BAR, 1, 2, 3), 2 @patch.object(pool, 'TaskGroupRegister')
mock_join_queue.return_value = mock_join = 'awaitable' @patch.object(pool.BaseTaskPool, '_check_start')
mock_iter.return_value = args_iter = iter(args) async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable' mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
output_q = self.task_pool._set_up_args_queue(args, num_tasks) mock_create_task: MagicMock):
self.assertIsInstance(output_q, Queue) mock_group_reg = set_up_mock_group_register(mock_reg_cls)
self.assertEqual(num_tasks, output_q.qsize()) mock_queue_cls.return_value = mock_q = MagicMock()
for arg in args[:num_tasks]: mock_join_queue.return_value = fake_join = object()
self.assertEqual(arg, output_q.get_nowait()) mock__queue_producer.return_value = fake_producer = object()
self.assertTrue(output_q.empty()) mock__queue_consumer.return_value = fake_consumer = object()
for arg in args[num_tasks:]: fake_task1, fake_task2 = object(), object()
self.assertEqual(arg, next(args_iter)) mock_create_task.side_effect = [fake_task1, fake_task2]
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_called_once_with(output_q, args_iter)
mock_create_task.assert_called_once_with(mock_producer_coro)
self.task_pool._before_gathering.clear() group_name, group_size = 'onetwothree', 0
mock_join_queue.reset_mock() func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
mock__queue_producer.reset_mock()
mock_create_task.reset_mock()
num_tasks = 6
mock_iter.return_value = args_iter = iter(args)
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(len(args), output_q.qsize())
for arg in args:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_not_called()
mock_create_task.assert_not_called()
@patch.object(pool, 'Event')
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock,
mock_event_cls: MagicMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_flag_set = MagicMock()
mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set)
mock_func, stars = MagicMock(), 3
args_iter, group_size = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._locked = False with self.assertRaises(ValueError):
with self.assertRaises(exceptions.PoolIsLocked): await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb) mock__check_start.assert_called_once_with(function=func)
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_flag_set.assert_not_called()
self.task_pool._locked = True mock__check_start.reset_mock()
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, group_size) group_size = 1234
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars, self.task_pool._task_groups = {group_name: MagicMock()}
end_callback=end_cb, cancel_callback=cancel_cb)])
mock_flag_set.assert_called_once_with() with self.assertRaises(exceptions.InvalidGroupName):
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=func)
mock__check_start.reset_mock()
self.task_pool._task_groups.clear()
self.task_pool._before_gathering = []
self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
mock__check_start.assert_called_once_with(function=func)
mock_reg_cls.assert_called_once_with()
self.task_pool._task_groups[group_name] = mock_group_reg
mock_group_reg.__aenter__.assert_awaited_once_with()
mock_queue_cls.assert_called_once_with(maxsize=group_size)
mock_join_queue.assert_called_once_with(mock_q)
self.assertListEqual([fake_join], self.task_pool._before_gathering)
mock__queue_producer.assert_called_once()
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
mock_group_reg.__aexit__.assert_awaited_once()
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock): @patch.object(pool.TaskPool, '_generate_group_name')
async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock() mock_func = MagicMock()
arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2 arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb)) output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size, self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called()
mock__map.reset_mock()
output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('map', mock_func)
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock): @patch.object(pool.TaskPool, '_generate_group_name')
async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock() mock_func = MagicMock()
args_iter, group_size = ([FOO], [BAR]), 2 args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb)) output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size, self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called()
mock__map.reset_mock()
output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('starmap', mock_func)
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock): @patch.object(pool.TaskPool, '_generate_group_name')
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock() mock_func = MagicMock()
kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2 kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb)) output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size, self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called()
mock__map.reset_mock()
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
class SimpleTaskPoolTestCase(CommonTestCase): class SimpleTaskPoolTestCase(CommonTestCase):
@ -667,7 +806,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def test_stop(self, mock_cancel: MagicMock): def test_stop(self, mock_cancel: MagicMock):
num = 2 num = 2
id1, id2, id3 = 5, 6, 7 id1, id2, id3 = 5, 6, 7
self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR} self.task_pool._tasks_running = {id1: FOO, id2: BAR, id3: FOO + BAR}
output = self.task_pool.stop(num) output = self.task_pool.stop(num)
expected_output = [id3, id2] expected_output = [id3, id2]
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
@ -689,3 +828,10 @@ class SimpleTaskPoolTestCase(CommonTestCase):
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
mock_num_running.assert_called_once_with() mock_num_running.assert_called_once_with()
mock_stop.assert_called_once_with(num) mock_stop.assert_called_once_with(num)
def set_up_mock_group_register(mock_reg_cls: MagicMock) -> MagicMock:
mock_grp_aenter, mock_grp_aexit, mock_grp_add = AsyncMock(), AsyncMock(), MagicMock()
mock_reg_cls.return_value = mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit,
add=mock_grp_add)
return mock_group_reg

View File

@ -28,7 +28,7 @@ async def work(n: int) -> None:
""" """
for i in range(n): for i in range(n):
await asyncio.sleep(1) await asyncio.sleep(1)
print("did", i) print("> did", i)
async def main() -> None: async def main() -> None:
@ -39,7 +39,7 @@ async def main() -> None:
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2 pool.stop(2) # cancels tasks 3 and 2
pool.lock() # required for the last line pool.lock() # required for the last line
await pool.gather() # awaits all tasks, then flushes the pool await pool.gather_and_close() # awaits all tasks, then flushes the pool
if __name__ == '__main__': if __name__ == '__main__':
@ -52,29 +52,29 @@ SimpleTaskPool-0 initialized
Started SimpleTaskPool-0_Task-0 Started SimpleTaskPool-0_Task-0
Started SimpleTaskPool-0_Task-1 Started SimpleTaskPool-0_Task-1
Started SimpleTaskPool-0_Task-2 Started SimpleTaskPool-0_Task-2
did 0 > did 0
did 0 > did 0
did 0 > did 0
Started SimpleTaskPool-0_Task-3 Started SimpleTaskPool-0_Task-3
did 1 > did 1
did 1 > did 1
did 1 > did 1
did 0 > did 0
> did 2
> did 2
SimpleTaskPool-0 is locked! SimpleTaskPool-0 is locked!
Cancelling SimpleTaskPool-0_Task-3 ...
Cancelled SimpleTaskPool-0_Task-3
Ended SimpleTaskPool-0_Task-3
Cancelling SimpleTaskPool-0_Task-2 ... Cancelling SimpleTaskPool-0_Task-2 ...
Cancelled SimpleTaskPool-0_Task-2 Cancelled SimpleTaskPool-0_Task-2
Ended SimpleTaskPool-0_Task-2 Ended SimpleTaskPool-0_Task-2
did 2 Cancelling SimpleTaskPool-0_Task-3 ...
did 2 Cancelled SimpleTaskPool-0_Task-3
did 3 Ended SimpleTaskPool-0_Task-3
did 3 > did 3
> did 3
Ended SimpleTaskPool-0_Task-0 Ended SimpleTaskPool-0_Task-0
Ended SimpleTaskPool-0_Task-1 Ended SimpleTaskPool-0_Task-1
did 4 > did 4
did 4 > did 4
``` ```
## Advanced example for `TaskPool` ## Advanced example for `TaskPool`
@ -101,21 +101,21 @@ async def work(start: int, stop: int, step: int = 1) -> None:
"""Pseudo-worker function counting through a range with a second of sleep in between each iteration.""" """Pseudo-worker function counting through a range with a second of sleep in between each iteration."""
for i in range(start, stop, step): for i in range(start, stop, step):
await asyncio.sleep(1) await asyncio.sleep(1)
print("work with", i) print("> work with", i)
async def other_work(a: int, b: int) -> None: async def other_work(a: int, b: int) -> None:
"""Different pseudo-worker counting through a range with half a second of sleep in between each iteration.""" """Different pseudo-worker counting through a range with half a second of sleep in between each iteration."""
for i in range(a, b): for i in range(a, b):
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
print("other_work with", i) print("> other_work with", i)
async def main() -> None: async def main() -> None:
# Initialize a new task pool instance and limit its size to 3 tasks. # Initialize a new task pool instance and limit its size to 3 tasks.
pool = TaskPool(3) pool = TaskPool(3)
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments). # Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments).
print("Called `apply`") print("> Called `apply`")
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
# Let the tasks work for a bit. # Let the tasks work for a bit.
await asyncio.sleep(1.5) await asyncio.sleep(1.5)
@ -124,20 +124,18 @@ async def main() -> None:
# Since we set our pool size to 3, and already have two tasks working within the pool, # Since we set our pool size to 3, and already have two tasks working within the pool,
# only the first one of these will start immediately (and receive ID 2). # only the first one of these will start immediately (and receive ID 2).
# The second one will start (with ID 3), only once there is room in the pool, # The second one will start (with ID 3), only once there is room in the pool,
# which -- in this example -- will be the case after ID 2 ends; # which -- in this example -- will be the case after ID 2 ends.
# until then the `starmap` method call **will block**!
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5) # Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
# **only** once there is room in the pool **and** no more than one of these last four tasks is running. # **only** once there is room in the pool **and** no more than one other task of these new ones is running.
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)] args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
print("Calling `starmap`...")
await pool.starmap(other_work, args_list, group_size=2) await pool.starmap(other_work, args_list, group_size=2)
print("`starmap` returned") print("> Called `starmap`")
# Now we lock the pool, so that we can safely await all our tasks. # Now we lock the pool, so that we can safely await all our tasks.
pool.lock() pool.lock()
# Finally, we block, until all tasks have ended. # Finally, we block, until all tasks have ended.
print("Called `gather`") print("> Calling `gather_and_close`...")
await pool.gather() await pool.gather_and_close()
print("Done.") print("> Done.")
if __name__ == '__main__': if __name__ == '__main__':
@ -152,82 +150,81 @@ Additional comments for the output are provided with `<---` next to the output l
TaskPool-0 initialized TaskPool-0 initialized
Started TaskPool-0_Task-0 Started TaskPool-0_Task-0
Started TaskPool-0_Task-1 Started TaskPool-0_Task-1
Called `apply` > Called `apply`
work with 100 > work with 100
work with 100 > work with 100
Calling `starmap`... <--- notice that this blocks as expected > Called `starmap` <--- notice that this immediately returns, even before Task-2 is started
Started TaskPool-0_Task-2 > Calling `gather_and_close`... <--- this blocks `main()` until all tasks have ended
work with 110
work with 110
other_work with 0
other_work with 1
work with 120
work with 120
other_work with 2
other_work with 3
work with 130
work with 130
other_work with 4
other_work with 5
work with 140
work with 140
other_work with 6
other_work with 7
work with 150
work with 150
other_work with 8
Ended TaskPool-0_Task-2 <--- here Task-2 makes room in the pool and unblocks `main()`
TaskPool-0 is locked! TaskPool-0 is locked!
Started TaskPool-0_Task-2 <--- at this point the pool is full
> work with 110
> work with 110
> other_work with 0
> other_work with 1
> work with 120
> work with 120
> other_work with 2
> other_work with 3
> work with 130
> work with 130
> other_work with 4
> other_work with 5
> work with 140
> work with 140
> other_work with 6
> other_work with 7
> work with 150
> work with 150
> other_work with 8
Ended TaskPool-0_Task-2 <--- this frees up room for one more task from `starmap`
Started TaskPool-0_Task-3 Started TaskPool-0_Task-3
other_work with 9 > other_work with 9
`starmap` returned > work with 160
Called `gather` <--- now this will block `main()` until all tasks have ended > work with 160
work with 160 > other_work with 10
work with 160 > other_work with 11
other_work with 10 > work with 170
other_work with 11 > work with 170
work with 170 > other_work with 12
work with 170 > other_work with 13
other_work with 12 > work with 180
other_work with 13 > work with 180
work with 180 > other_work with 14
work with 180 > other_work with 15
other_work with 14
other_work with 15
Ended TaskPool-0_Task-0 Ended TaskPool-0_Task-0
Ended TaskPool-0_Task-1 <--- even though there is room in the pool now, Task-5 will not start Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
Started TaskPool-0_Task-4 Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start
work with 190 > work with 190
work with 190 > work with 190
other_work with 16 > other_work with 16
other_work with 20 > other_work with 17
other_work with 17 > other_work with 20
other_work with 21 > other_work with 18
other_work with 18 > other_work with 21
other_work with 22 Ended TaskPool-0_Task-3 <--- now that only Task-4 of the group remains, Task-5 starts
other_work with 19
Ended TaskPool-0_Task-3 <--- now that only Task-4 is left, Task-5 will start
Started TaskPool-0_Task-5 Started TaskPool-0_Task-5
other_work with 23 > other_work with 19
other_work with 30 > other_work with 22
other_work with 24 > other_work with 23
other_work with 31 > other_work with 30
other_work with 25 > other_work with 24
other_work with 32 > other_work with 31
other_work with 26 > other_work with 25
other_work with 33 > other_work with 32
other_work with 27 > other_work with 26
other_work with 34 > other_work with 33
other_work with 28 > other_work with 27
other_work with 35 > other_work with 34
> other_work with 28
> other_work with 35
> other_work with 29
> other_work with 36
Ended TaskPool-0_Task-4 Ended TaskPool-0_Task-4
other_work with 29 > other_work with 37
other_work with 36 > other_work with 38
other_work with 37 > other_work with 39
other_work with 38
other_work with 39
Done.
Ended TaskPool-0_Task-5 Ended TaskPool-0_Task-5
> Done.
``` ```
© 2022 Daniil Fajnberg © 2022 Daniil Fajnberg

View File

@ -74,12 +74,12 @@ async def main() -> None:
control_server_task.cancel() control_server_task.cancel()
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks. # we can now safely cancel their tasks.
pool.stop_all()
pool.lock() pool.lock()
pool.stop_all()
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
# we just silently collect their exceptions along with their return values. # we just silently collect their exceptions along with their return values.
await pool.gather(return_exceptions=True) await pool.gather_and_close(return_exceptions=True)
await control_server_task await control_server_task