diff --git a/README.md b/README.md index 9b67b16..4f5bce2 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ async def work(_foo, _bar): ... async def main(): pool = SimpleTaskPool(work, args=('xyz', 420)) - await pool.start(5) + pool.start(5) ... pool.stop(3) ... diff --git a/docs/source/pages/pool.rst b/docs/source/pages/pool.rst index 6b57883..3d9a2d0 100644 --- a/docs/source/pages/pool.rst +++ b/docs/source/pages/pool.rst @@ -147,9 +147,11 @@ Or we could use a task pool: Calling the :py:meth:`.map() ` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction). +The :py:meth:`.gather_and_close() ` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`) + .. note:: - The :py:meth:`.gather_and_close() ` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`) + Neither :py:meth:`.apply() ` nor :py:meth:`.map() ` return coroutines. When they are called, the task pool immediately begins scheduling new tasks to run. No :code:`await` needed. It can't get any simpler than that, can it? So glad you asked... @@ -168,7 +170,7 @@ Let's take the :ref:`queue worker example ` from before. async def main(): ... pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out)) - await pool.start(5) + pool.start(5) ... pool.stop_all() ... @@ -193,9 +195,9 @@ This may, at first glance, not seem like much of a difference, aside from differ if some_condition and pool.num_running > 10: pool.stop(3) elif some_other_condition and pool.num_running < 5: - await pool.start(5) + pool.start(5) else: - await pool.start(1) + pool.start(1) ... await pool.gather_and_close() @@ -228,6 +230,4 @@ The only method of a pool that one should **always** assume to be blocking is :p One method to be aware of is :py:meth:`.flush() `. 
Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations. -In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, :py:meth:`SimpleTaskPool.start() ` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased). - -:py:meth:`TaskPool.apply() ` and :py:meth:`TaskPool.map() ` (and its variants) will **never** block. Since they make use of "meta-tasks" under the hood, they will always return immediately. However, if the pool was full when one of them was called, there is **no guarantee** that even a single task has started, when the method returns. +All methods that add tasks to a pool, i.e. :py:meth:`TaskPool.map() ` (and its variants), :py:meth:`TaskPool.apply() ` and :py:meth:`SimpleTaskPool.start() `, are non-blocking by design. They all make use of "meta tasks" under the hood and return immediately. It is important to realize, however, that the mere fact that they return does not mean that any actual tasks have been spawned. For example, if a pool size limit was set and there was "no more room" in the pool when :py:meth:`.map() ` was called, there is **no guarantee** that even a single task has started when it returns. 
diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index cab3000..ae1ea92 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -593,6 +593,7 @@ class TaskPool(BaseTaskPool): """ if kwargs is None: kwargs = {} + # TODO: Add exception logging await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback, cancel_callback=cancel_callback) for _ in range(num))) @@ -610,8 +611,8 @@ class TaskPool(BaseTaskPool): because this method returns immediately, this does not mean that any task was started or that any number of tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`. - If the entire task group is cancelled, the meta task is cancelled first, which may cause the number of tasks - spawned to be less than `num`. + If the entire task group is cancelled before `num` tasks have spawned, since the meta task is cancelled first, + the number of tasks spawned will end up being less than `num`. Args: func: @@ -640,10 +641,13 @@ class TaskPool(BaseTaskPool): `PoolIsClosed`: The pool is closed. `NotCoroutine`: `func` is not a coroutine function. `PoolIsLocked`: The pool is currently locked. + `InvalidGroupName`: A group named `group_name` exists in the pool. 
""" self._check_start(function=func) if group_name is None: group_name = self._generate_group_name('apply', func) + if group_name in self._task_groups.keys(): + raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!") self._task_groups.setdefault(group_name, TaskGroupRegister()) meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num, @@ -913,16 +917,25 @@ class SimpleTaskPool(BaseTaskPool): """Name of the coroutine function used in the pool.""" return self._func.__name__ - async def _start_one(self, group_name: str) -> int: - """Starts a single new task within the pool and returns its ID.""" - return await self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name, - end_callback=self._end_callback, cancel_callback=self._cancel_callback) + async def _start_num(self, num: int, group_name: str) -> None: + """Starts `num` new tasks in group `group_name`.""" + start_coroutines = ( + self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name, + end_callback=self._end_callback, cancel_callback=self._cancel_callback) + for _ in range(num) + ) + await gather(*start_coroutines) - async def start(self, num: int) -> str: + def start(self, num: int) -> str: """ Starts specified number of new tasks in the pool as a new group. - This method may block if there is less room in the pool than the desired number of new tasks. + Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just + because this method returns immediately, this does not mean that any task was started or that any number of + tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`. + + If the entire task group is cancelled before `num` tasks have spawned, since the meta task is cancelled first, + the number of tasks spawned will end up being less than `num`. 
Args: num: The number of new tasks to start. @@ -931,9 +944,12 @@ class SimpleTaskPool(BaseTaskPool): The name of the newly created task group in the form :code:`'start-group-{idx}'` (with `idx` being an incrementing index). """ + self._check_start(function=self._func) group_name = f'start-group-{self._start_calls}' self._start_calls += 1 - await gather(*(self._start_one(group_name) for _ in range(num))) + self._task_groups.setdefault(group_name, TaskGroupRegister()) + meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set()) + meta_tasks.add(create_task(self._start_num(num, group_name))) return group_name def stop(self, num: int) -> List[int]: diff --git a/tests/test_pool.py b/tests/test_pool.py index f4bb175..689c5a7 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -711,25 +711,42 @@ class SimpleTaskPoolTestCase(CommonTestCase): self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name) @patch.object(pool.SimpleTaskPool, '_start_task') - async def test__start_one(self, mock__start_task: AsyncMock): - mock__start_task.return_value = expected_output = 99 - self.task_pool._func = MagicMock(return_value=BAR) + async def test__start_num(self, mock__start_task: AsyncMock): + fake_coroutine = object() + self.task_pool._func = MagicMock(return_value=fake_coroutine) + num = 3 group_name = FOO + BAR + 'abc' - output = await self.task_pool._start_one(group_name) - self.assertEqual(expected_output, output) - self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs) - mock__start_task.assert_awaited_once_with(BAR, group_name=group_name, end_callback=self.task_pool._end_callback, - cancel_callback=self.task_pool._cancel_callback) + self.assertIsNone(await self.task_pool._start_num(num, group_name)) + self.task_pool._func.assert_has_calls(num * [ + call(*self.task_pool._args, **self.task_pool._kwargs) + ]) + mock__start_task.assert_has_awaits(num * [ + call(fake_coroutine, group_name=group_name, 
end_callback=self.task_pool._end_callback, + cancel_callback=self.task_pool._cancel_callback) + ]) - @patch.object(pool.SimpleTaskPool, '_start_one') - async def test_start(self, mock__start_one: AsyncMock): - mock__start_one.return_value = FOO + @patch.object(pool, 'create_task') + @patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock()) + @patch.object(pool, 'TaskGroupRegister') + @patch.object(pool.BaseTaskPool, '_check_start') + def test_start(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__start_num: AsyncMock, + mock_create_task: MagicMock): + mock_group_reg = set_up_mock_group_register(mock_reg_cls) + mock__start_num.return_value = mock_start_num_coroutine = object() + mock_create_task.return_value = fake_task = object() + self.task_pool._task_groups = {} + self.task_pool._group_meta_tasks_running = {} num = 5 self.task_pool._start_calls = 42 - output = await self.task_pool.start(num) - expected_output = 'start-group-42' - self.assertEqual(expected_output, output) - mock__start_one.assert_has_awaits(num * [call(expected_output)]) + expected_group_name = 'start-group-42' + output = self.task_pool.start(num) + self.assertEqual(expected_group_name, output) + mock__check_start.assert_called_once_with(function=self.TEST_POOL_FUNC) + self.assertEqual(43, self.task_pool._start_calls) + self.assertEqual(mock_group_reg, self.task_pool._task_groups[expected_group_name]) + mock__start_num.assert_called_once_with(num, expected_group_name) + mock_create_task.assert_called_once_with(mock_start_num_coroutine) + self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[expected_group_name]) @patch.object(pool.SimpleTaskPool, 'cancel') def test_stop(self, mock_cancel: MagicMock): diff --git a/usage/USAGE.md b/usage/USAGE.md index 3ce38e9..9ed84a3 100644 --- a/usage/USAGE.md +++ b/usage/USAGE.md @@ -39,9 +39,9 @@ async def work(n: int) -> None: async def main() -> None: pool = SimpleTaskPool(work, args=(5,)) # initializes 
the pool; no work is being done yet - await pool.start(3) # launches work tasks 0, 1, and 2 + pool.start(3) # launches work tasks 0, 1, and 2 await asyncio.sleep(1.5) # lets the tasks work for a bit - await pool.start(1) # launches work task 3 + pool.start(1) # launches work task 3 await asyncio.sleep(1.5) # lets the tasks work for a bit pool.stop(2) # cancels tasks 3 and 2 (LIFO order) await pool.gather_and_close() # awaits all tasks, then flushes the pool diff --git a/usage/example_server.py b/usage/example_server.py index 454c413..baffb98 100644 --- a/usage/example_server.py +++ b/usage/example_server.py @@ -67,7 +67,7 @@ async def main() -> None: for item in range(100): q.put_nowait(item) pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool - await pool.start(3) # launches three worker tasks + pool.start(3) # launches three worker tasks control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever() # We block until `.task_done()` has been called once by our workers for every item placed into the queue. await q.join()