diff --git a/src/asyncio_taskpool/internals/constants.py b/src/asyncio_taskpool/internals/constants.py index dc06e46..693f2be 100644 --- a/src/asyncio_taskpool/internals/constants.py +++ b/src/asyncio_taskpool/internals/constants.py @@ -25,8 +25,6 @@ PACKAGE_NAME = 'asyncio_taskpool' DEFAULT_TASK_GROUP = 'default' -DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S' - SESSION_MSG_BYTES = 1024 * 100 STREAM_WRITER = 'stream_writer' diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 0919593..6fbc819 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -33,12 +33,11 @@ from asyncio.exceptions import CancelledError from asyncio.locks import Semaphore from asyncio.tasks import Task, create_task, gather from contextlib import suppress -from datetime import datetime from math import inf from typing import Any, Awaitable, Dict, Iterable, List, Set, Union from . import exceptions -from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT +from .internals.constants import DEFAULT_TASK_GROUP from .internals.group_register import TaskGroupRegister from .internals.helpers import execute_optional, star_function from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB @@ -615,19 +614,24 @@ class TaskPool(BaseTaskPool): self._group_meta_tasks_running.clear() await super().gather_and_close(return_exceptions=return_exceptions) - @staticmethod - def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str: + def _generate_group_name(self, prefix: str, coroutine_function: CoroutineFunc) -> str: """ - Creates a task group identifier that includes the current datetime. + Creates a unique task group identifier. Args: prefix: The start of the name; will be followed by an underscore. coroutine_function: The function representing the task group. Returns: - The constructed 'prefix_function_datetime' string to name a task group. + The constructed 'prefix_function_index' string to name a task group. """ - return f'{prefix}_{coroutine_function.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}' + base_name = f'{prefix}_{coroutine_function.__name__}' + i = 0 + while True: + name = f'{base_name}_{i}' + if name not in self._task_groups.keys(): + return name + i += 1 async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None: @@ -683,7 +687,8 @@ class TaskPool(BaseTaskPool): num (optional): The number of tasks to spawn with the specified parameters. group_name (optional): - Name of the task group to add the new tasks to. + Name of the task group to add the new tasks to. By default, a unique name is constructed using the + name of the provided `func` and an incrementing index as 'apply_func_index'. end_callback (optional): A callback to execute after a task has ended. It is run with the task's ID as its only positional argument. @@ -692,7 +697,7 @@ class TaskPool(BaseTaskPool): It is run with the task's ID as its only positional argument. Returns: - The name of the task group that the newly spawned tasks have been added to. + The name of the newly created task group (see the `group_name` parameter). Raises: `PoolIsClosed`: The pool is closed. @@ -711,7 +716,7 @@ class TaskPool(BaseTaskPool): @staticmethod def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB: - """Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`.""" + """Returns a wrapped `end_callback` for each :meth:`_arg_consumer` task that releases the `map_semaphore`.""" async def release_callback(task_id: int) -> None: map_semaphore.release() await execute_optional(actual_end_callback, args=(task_id,)) @@ -848,6 +853,8 @@ class TaskPool(BaseTaskPool): The number new tasks spawned by this method to run concurrently. Defaults to 1. group_name (optional): Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet. + By default, a unique name is constructed using the name of the provided `func` and an incrementing + index as 'apply_func_index'. end_callback (optional): A callback to execute after a task has ended. It is run with the task's ID as its only positional argument. @@ -856,7 +863,7 @@ class TaskPool(BaseTaskPool): It is run with the task's ID as its only positional argument. Returns: - The name of the task group that the newly spawned tasks will be added to. + The name of the newly created task group (see the `group_name` parameter). Raises: `PoolIsClosed`: The pool is closed. diff --git a/tests/test_pool.py b/tests/test_pool.py index 183e501..57444f8 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -20,13 +20,11 @@ Unittests for the `asyncio_taskpool.pool` module. from asyncio.exceptions import CancelledError from asyncio.locks import Semaphore -from datetime import datetime from unittest import IsolatedAsyncioTestCase from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from typing import Type from asyncio_taskpool import pool, exceptions -from asyncio_taskpool.internals.constants import DATETIME_FORMAT EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set() @@ -363,7 +361,8 @@ class BaseTaskPoolTestCase(CommonTestCase): self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) - async def test_gather_and_close(self): + @patch.object(pool.BaseTaskPool, 'lock') + async def test_gather_and_close(self, mock_lock: MagicMock): mock_running_func = AsyncMock() mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) self.task_pool._tasks_ended = {123: mock_ended_func()} @@ -372,6 +371,7 @@ class BaseTaskPoolTestCase(CommonTestCase): self.task_pool._locked = True self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) + mock_lock.assert_called_once_with() mock_ended_func.assert_awaited_once_with() mock_cancelled_func.assert_awaited_once_with() mock_running_func.assert_awaited_once_with() @@ -458,13 +458,15 @@ class TaskPoolTestCase(CommonTestCase): mock_cancelled_meta_task.assert_awaited_once_with() self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) + @patch.object(pool.BaseTaskPool, 'lock') @patch.object(pool.BaseTaskPool, 'gather_and_close') - async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock): + async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock, mock_lock: MagicMock): mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock() self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}} mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) + mock_lock.assert_called_once_with() mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True) mock_meta_task1.assert_awaited_once_with() mock_meta_task2.assert_awaited_once_with() @@ -472,13 +474,15 @@ class TaskPoolTestCase(CommonTestCase): self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) - @patch.object(pool, 'datetime') - def test__generate_group_name(self, mock_datetime: MagicMock): + def test__generate_group_name(self): prefix, func = 'x y z', AsyncMock(__name__=BAR) - dt = datetime(1776, 7, 4, 0, 0, 1) - mock_datetime.now = MagicMock(return_value=dt) - expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}' - output = pool.TaskPool._generate_group_name(prefix, func) + self.task_pool._task_groups = { + f'{prefix}_{BAR}_0': MagicMock(), + f'{prefix}_{BAR}_1': MagicMock(), + f'{prefix}_{BAR}_100': MagicMock(), + } + expected_output = f'{prefix}_{BAR}_2' + output = self.task_pool._generate_group_name(prefix, func) self.assertEqual(expected_output, output) @patch.object(pool.TaskPool, '_start_task')