made `start` "non-async" using meta task

This commit is contained in:
Daniil Fajnberg 2022-03-30 21:51:19 +02:00
parent a72a7cc516
commit 80fc91ec47
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
6 changed files with 68 additions and 35 deletions

View File

@ -32,7 +32,7 @@ async def work(_foo, _bar): ...
async def main():
pool = SimpleTaskPool(work, args=('xyz', 420))
await pool.start(5)
pool.start(5)
...
pool.stop(3)
...

View File

@ -147,9 +147,11 @@ Or we could use a task pool:
Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction).
The :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`)
.. note::
The :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`)
Neither :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` nor :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` return coroutines. When they are called, the task pool immediately begins scheduling new tasks to run. No :code:`await` needed.
It can't get any simpler than that, can it? So glad you asked...
@ -168,7 +170,7 @@ Let's take the :ref:`queue worker example <queue-worker-function>` from before.
async def main():
...
pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out))
await pool.start(5)
pool.start(5)
...
pool.stop_all()
...
@ -193,9 +195,9 @@ This may, at first glance, not seem like much of a difference, aside from differ
if some_condition and pool.num_running > 10:
pool.stop(3)
elif some_other_condition and pool.num_running < 5:
await pool.start(5)
pool.start(5)
else:
await pool.start(1)
pool.start(1)
...
await pool.gather_and_close()
@ -228,6 +230,4 @@ The only method of a pool that one should **always** assume to be blocking is :p
One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations.
In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased).
:py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants) will **never** block. Since they make use of "meta-tasks" under the hood, they will always return immediately. However, if the pool was full when one of them was called, there is **no guarantee** that even a single task has started, when the method returns.
All methods that add tasks to a pool, i.e. :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants), :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>`, are non-blocking by design. They all make use of "meta tasks" under the hood and return immediately. It is important however, to realize that just because they return, does not mean that any actual tasks have been spawned. For example, if a pool size limit was set and there was "no more room" in the pool when :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` was called, there is **no guarantee** that even a single task has started, when it returns.

View File

@ -593,6 +593,7 @@ class TaskPool(BaseTaskPool):
"""
if kwargs is None:
kwargs = {}
# TODO: Add exception logging
await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
cancel_callback=cancel_callback) for _ in range(num)))
@ -610,8 +611,8 @@ class TaskPool(BaseTaskPool):
because this method returns immediately, this does not mean that any task was started or that any number of
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`.
If the entire task group is cancelled, the meta task is cancelled first, which may cause the number of tasks
spawned to be less than `num`.
If the entire task group is cancelled before `num` tasks have spawned, since the meta task is cancelled first,
the number of tasks spawned will end up being less than `num`.
Args:
func:
@ -640,10 +641,13 @@ class TaskPool(BaseTaskPool):
`PoolIsClosed`: The pool is closed.
`NotCoroutine`: `func` is not a coroutine function.
`PoolIsLocked`: The pool is currently locked.
`InvalidGroupName`: A group named `group_name` exists in the pool.
"""
self._check_start(function=func)
if group_name is None:
group_name = self._generate_group_name('apply', func)
if group_name in self._task_groups.keys():
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
self._task_groups.setdefault(group_name, TaskGroupRegister())
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
meta_tasks.add(create_task(self._apply_num(group_name, func, args, kwargs, num,
@ -913,16 +917,25 @@ class SimpleTaskPool(BaseTaskPool):
"""Name of the coroutine function used in the pool."""
return self._func.__name__
async def _start_one(self, group_name: str) -> int:
"""Starts a single new task within the pool and returns its ID."""
return await self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
async def _start_num(self, num: int, group_name: str) -> None:
"""Starts `num` new tasks in group `group_name`."""
start_coroutines = (
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
for _ in range(num)
)
await gather(*start_coroutines)
async def start(self, num: int) -> str:
def start(self, num: int) -> str:
"""
Starts specified number of new tasks in the pool as a new group.
This method may block if there is less room in the pool than the desired number of new tasks.
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
because this method returns immediately, this does not mean that any task was started or that any number of
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`.
If the entire task group is cancelled before `num` tasks have spawned, since the meta task is cancelled first,
the number of tasks spawned will end up being less than `num`.
Args:
num: The number of new tasks to start.
@ -931,9 +944,12 @@ class SimpleTaskPool(BaseTaskPool):
The name of the newly created task group in the form :code:`'start-group-{idx}'`
(with `idx` being an incrementing index).
"""
self._check_start(function=self._func)
group_name = f'start-group-{self._start_calls}'
self._start_calls += 1
await gather(*(self._start_one(group_name) for _ in range(num)))
self._task_groups.setdefault(group_name, TaskGroupRegister())
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
meta_tasks.add(create_task(self._start_num(num, group_name)))
return group_name
def stop(self, num: int) -> List[int]:

View File

@ -711,25 +711,42 @@ class SimpleTaskPoolTestCase(CommonTestCase):
self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name)
@patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 99
self.task_pool._func = MagicMock(return_value=BAR)
async def test__start_num(self, mock__start_task: AsyncMock):
fake_coroutine = object()
self.task_pool._func = MagicMock(return_value=fake_coroutine)
num = 3
group_name = FOO + BAR + 'abc'
output = await self.task_pool._start_one(group_name)
self.assertEqual(expected_output, output)
self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs)
mock__start_task.assert_awaited_once_with(BAR, group_name=group_name, end_callback=self.task_pool._end_callback,
cancel_callback=self.task_pool._cancel_callback)
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(num * [
call(*self.task_pool._args, **self.task_pool._kwargs)
])
mock__start_task.assert_has_awaits(num * [
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback,
cancel_callback=self.task_pool._cancel_callback)
])
@patch.object(pool.SimpleTaskPool, '_start_one')
async def test_start(self, mock__start_one: AsyncMock):
mock__start_one.return_value = FOO
@patch.object(pool, 'create_task')
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start')
def test_start(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__start_num: AsyncMock,
mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__start_num.return_value = mock_start_num_coroutine = object()
mock_create_task.return_value = fake_task = object()
self.task_pool._task_groups = {}
self.task_pool._group_meta_tasks_running = {}
num = 5
self.task_pool._start_calls = 42
output = await self.task_pool.start(num)
expected_output = 'start-group-42'
self.assertEqual(expected_output, output)
mock__start_one.assert_has_awaits(num * [call(expected_output)])
expected_group_name = 'start-group-42'
output = self.task_pool.start(num)
self.assertEqual(expected_group_name, output)
mock__check_start.assert_called_once_with(function=self.TEST_POOL_FUNC)
self.assertEqual(43, self.task_pool._start_calls)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[expected_group_name])
mock__start_num.assert_called_once_with(num, expected_group_name)
mock_create_task.assert_called_once_with(mock_start_num_coroutine)
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[expected_group_name])
@patch.object(pool.SimpleTaskPool, 'cancel')
def test_stop(self, mock_cancel: MagicMock):

View File

@ -39,9 +39,9 @@ async def work(n: int) -> None:
async def main() -> None:
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
await pool.start(3) # launches work tasks 0, 1, and 2
pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit
await pool.start(1) # launches work task 3
pool.start(1) # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
await pool.gather_and_close() # awaits all tasks, then flushes the pool

View File

@ -67,7 +67,7 @@ async def main() -> None:
for item in range(100):
q.put_nowait(item)
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
await pool.start(3) # launches three worker tasks
pool.start(3) # launches three worker tasks
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join()