Compare commits

...

2 Commits

Author SHA1 Message Date
82e6ca7b1a
task group naming logic changed 2022-03-30 12:37:32 +02:00
153127e028
lock before gathering meta tasks 2022-03-30 11:47:15 +02:00
3 changed files with 33 additions and 23 deletions

View File

@ -25,8 +25,6 @@ PACKAGE_NAME = 'asyncio_taskpool'
DEFAULT_TASK_GROUP = 'default' DEFAULT_TASK_GROUP = 'default'
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100
STREAM_WRITER = 'stream_writer' STREAM_WRITER = 'stream_writer'

View File

@ -33,12 +33,11 @@ from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Semaphore
from asyncio.tasks import Task, create_task, gather from asyncio.tasks import Task, create_task, gather
from contextlib import suppress from contextlib import suppress
from datetime import datetime
from math import inf from math import inf
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
from . import exceptions from . import exceptions
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT from .internals.constants import DEFAULT_TASK_GROUP
from .internals.group_register import TaskGroupRegister from .internals.group_register import TaskGroupRegister
from .internals.helpers import execute_optional, star_function from .internals.helpers import execute_optional, star_function
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
@ -607,6 +606,7 @@ class TaskPool(BaseTaskPool):
Raises: Raises:
`PoolStillUnlocked`: The pool has not been locked yet. `PoolStillUnlocked`: The pool has not been locked yet.
""" """
self.lock()
not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set) not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
with suppress(CancelledError): with suppress(CancelledError):
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions) await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
@ -614,19 +614,24 @@ class TaskPool(BaseTaskPool):
self._group_meta_tasks_running.clear() self._group_meta_tasks_running.clear()
await super().gather_and_close(return_exceptions=return_exceptions) await super().gather_and_close(return_exceptions=return_exceptions)
@staticmethod def _generate_group_name(self, prefix: str, coroutine_function: CoroutineFunc) -> str:
def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
""" """
Creates a task group identifier that includes the current datetime. Creates a unique task group identifier.
Args: Args:
prefix: The start of the name; will be followed by an underscore. prefix: The start of the name; will be followed by an underscore.
coroutine_function: The function representing the task group. coroutine_function: The function representing the task group.
Returns: Returns:
The constructed 'prefix_function_datetime' string to name a task group. The constructed 'prefix_function_index' string to name a task group.
""" """
return f'{prefix}_{coroutine_function.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}' base_name = f'{prefix}_{coroutine_function.__name__}'
i = 0
while True:
name = f'{base_name}_{i}'
if name not in self._task_groups.keys():
return name
i += 1
async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
@ -682,7 +687,8 @@ class TaskPool(BaseTaskPool):
num (optional): num (optional):
The number of tasks to spawn with the specified parameters. The number of tasks to spawn with the specified parameters.
group_name (optional): group_name (optional):
Name of the task group to add the new tasks to. Name of the task group to add the new tasks to. By default, a unique name is constructed using the
name of the provided `func` and an incrementing index as 'apply_func_index'.
end_callback (optional): end_callback (optional):
A callback to execute after a task has ended. A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
@ -691,7 +697,7 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
Returns: Returns:
The name of the task group that the newly spawned tasks have been added to. The name of the newly created task group (see the `group_name` parameter).
Raises: Raises:
`PoolIsClosed`: The pool is closed. `PoolIsClosed`: The pool is closed.
@ -710,7 +716,7 @@ class TaskPool(BaseTaskPool):
@staticmethod @staticmethod
def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB: def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB:
"""Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`.""" """Returns a wrapped `end_callback` for each :meth:`_arg_consumer` task that releases the `map_semaphore`."""
async def release_callback(task_id: int) -> None: async def release_callback(task_id: int) -> None:
map_semaphore.release() map_semaphore.release()
await execute_optional(actual_end_callback, args=(task_id,)) await execute_optional(actual_end_callback, args=(task_id,))
@ -847,6 +853,8 @@ class TaskPool(BaseTaskPool):
The number new tasks spawned by this method to run concurrently. Defaults to 1. The number new tasks spawned by this method to run concurrently. Defaults to 1.
group_name (optional): group_name (optional):
Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet. Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet.
By default, a unique name is constructed using the name of the provided `func` and an incrementing
index as 'apply_func_index'.
end_callback (optional): end_callback (optional):
A callback to execute after a task has ended. A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
@ -855,7 +863,7 @@ class TaskPool(BaseTaskPool):
It is run with the task's ID as its only positional argument. It is run with the task's ID as its only positional argument.
Returns: Returns:
The name of the task group that the newly spawned tasks will be added to. The name of the newly created task group (see the `group_name` parameter).
Raises: Raises:
`PoolIsClosed`: The pool is closed. `PoolIsClosed`: The pool is closed.

View File

@ -20,13 +20,11 @@ Unittests for the `asyncio_taskpool.pool` module.
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Semaphore
from datetime import datetime
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type from typing import Type
from asyncio_taskpool import pool, exceptions from asyncio_taskpool import pool, exceptions
from asyncio_taskpool.internals.constants import DATETIME_FORMAT
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set() EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
@ -363,7 +361,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
async def test_gather_and_close(self): @patch.object(pool.BaseTaskPool, 'lock')
async def test_gather_and_close(self, mock_lock: MagicMock):
mock_running_func = AsyncMock() mock_running_func = AsyncMock()
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._tasks_ended = {123: mock_ended_func()} self.task_pool._tasks_ended = {123: mock_ended_func()}
@ -372,6 +371,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._locked = True self.task_pool._locked = True
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_lock.assert_called_once_with()
mock_ended_func.assert_awaited_once_with() mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with() mock_cancelled_func.assert_awaited_once_with()
mock_running_func.assert_awaited_once_with() mock_running_func.assert_awaited_once_with()
@ -458,13 +458,15 @@ class TaskPoolTestCase(CommonTestCase):
mock_cancelled_meta_task.assert_awaited_once_with() mock_cancelled_meta_task.assert_awaited_once_with()
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
@patch.object(pool.BaseTaskPool, 'lock')
@patch.object(pool.BaseTaskPool, 'gather_and_close') @patch.object(pool.BaseTaskPool, 'gather_and_close')
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock): async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock, mock_lock: MagicMock):
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock() mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}} self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_lock.assert_called_once_with()
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True) mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
mock_meta_task1.assert_awaited_once_with() mock_meta_task1.assert_awaited_once_with()
mock_meta_task2.assert_awaited_once_with() mock_meta_task2.assert_awaited_once_with()
@ -472,13 +474,15 @@ class TaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
@patch.object(pool, 'datetime') def test__generate_group_name(self):
def test__generate_group_name(self, mock_datetime: MagicMock):
prefix, func = 'x y z', AsyncMock(__name__=BAR) prefix, func = 'x y z', AsyncMock(__name__=BAR)
dt = datetime(1776, 7, 4, 0, 0, 1) self.task_pool._task_groups = {
mock_datetime.now = MagicMock(return_value=dt) f'{prefix}_{BAR}_0': MagicMock(),
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}' f'{prefix}_{BAR}_1': MagicMock(),
output = pool.TaskPool._generate_group_name(prefix, func) f'{prefix}_{BAR}_100': MagicMock(),
}
expected_output = f'{prefix}_{BAR}_2'
output = self.task_pool._generate_group_name(prefix, func)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
@patch.object(pool.TaskPool, '_start_task') @patch.object(pool.TaskPool, '_start_task')