Compare commits

...

3 Commits

5 changed files with 398 additions and 51 deletions

View File

@@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.1.3
version = 0.1.6
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks

View File

@@ -3,6 +3,6 @@ MSG_BYTES = 1024
CMD_START = 'start'
CMD_STOP = 'stop'
CMD_STOP_ALL = 'stop_all'
CMD_SIZE = 'size'
CMD_NUM_RUNNING = 'num_running'
CMD_FUNC = 'func'
CLIENT_EXIT = 'exit'

View File

@@ -498,25 +498,47 @@ class TaskPool(BaseTaskPool):
await self._queue_consumer(q, func, arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
await execute_optional(end_callback, args=(task_id,))
def _fill_args_queue(self, q: Queue, args_iter: ArgsT, num_tasks: int) -> int:
def _set_up_args_queue(self, args_iter: ArgsT, num_tasks: int) -> Queue:
"""
Helper function for `_map()`.
Takes the iterable of function arguments `args_iter` and adds up to `num_tasks` of them to a new `asyncio.Queue`.
The queue's `join()` method is added to the pool's `_before_gathering` list and the queue is returned.
If the iterable contains fewer than `num_tasks` elements, nothing else happens; otherwise the `_queue_producer`
is started as a separate task with the arguments queue and an iterator of the remaining arguments.
Args:
args_iter:
The iterable of function arguments passed into `_map()` to use for creating the new tasks.
num_tasks:
The maximum number of new tasks to run concurrently, as passed into `_map()`.
Returns:
The newly created and filled arguments queue for spawning new tasks.
"""
# Setting the `maxsize` of the queue to `num_tasks` will ensure that no more than `num_tasks` tasks will run
# concurrently because the size of the queue is what will determine the number of immediately started tasks in
# the `_map()` method and each of those will only ever start (at most) one other task upon ending.
args_queue = Queue(maxsize=num_tasks)
self._before_gathering.append(join_queue(args_queue))
args_iter = iter(args_iter)
try:
# Here we guarantee that the queue will contain as many arguments as needed for starting the first batch of
# tasks, which will be at most `num_tasks` (meaning the queue will be full).
for i in range(num_tasks):
q.put_nowait(next(args_iter))
args_queue.put_nowait(next(args_iter))
except StopIteration:
# If we get here, this means that the number of elements in the arguments iterator was less than the
# specified `num_tasks`. Thus, the number of tasks to start immediately will be the size of the queue.
# specified `num_tasks`. Still, the number of tasks to start immediately will be the size of the queue.
# The `_queue_producer` won't be necessary, since we already put all the elements in the queue.
num_tasks = q.qsize()
pass
else:
# There may be more elements in the arguments iterator, so we need the `_queue_producer`.
# It will have exclusive access to the `args_iter` from now on.
# If the queue is full already, it will wait until one of the tasks in the first batch ends, before putting
# the next item in it.
create_task(self._queue_producer(q, args_iter))
return num_tasks
# Since the queue is full already, it will wait until one of the tasks in the first batch ends,
# before putting the next item in it.
create_task(self._queue_producer(args_queue, args_iter))
return args_queue
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
@@ -528,8 +550,7 @@ class TaskPool(BaseTaskPool):
This method blocks **only if** there is not enough room in the pool for the first batch of new tasks.
It sets up an internal queue which is filled while consuming the arguments iterable.
The queue's `join()` method is added to the pool's `_before_gathering` list.
It sets up an internal arguments queue which is continuously filled while consuming the arguments iterable.
Args:
func:
@@ -552,10 +573,8 @@ class TaskPool(BaseTaskPool):
"""
if not self.is_open:
raise exceptions.PoolIsClosed("Cannot start new tasks")
args_queue = Queue(maxsize=num_tasks)
self._before_gathering.append(join_queue(args_queue))
num_tasks = self._fill_args_queue(args_queue, args_iter, num_tasks)
for _ in range(num_tasks):
args_queue = self._set_up_args_queue(args_iter, num_tasks)
for _ in range(args_queue.qsize()):
# This is where blocking can occur, if the pool is full.
await self._queue_consumer(args_queue, func,
arg_stars=arg_stars, end_callback=end_callback, cancel_callback=cancel_callback)
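
The docstrings and comments above describe the bounded-queue pattern behind `_set_up_args_queue()` and `_map()`: the queue's `maxsize` caps how many tasks start immediately, and the producer feeds the remaining arguments only as slots free up. The following is a minimal, self-contained sketch of that idea, not the pool's actual implementation (it omits task IDs, callbacks, and the pool-size semaphore); all names in it are purely illustrative.

import asyncio
from typing import Any, Awaitable, Callable, Iterable

_STOP = object()  # sentinel telling a worker to shut down

async def map_concurrently(func: Callable[[Any], Awaitable[Any]],
                           args_iter: Iterable[Any], num_tasks: int = 2) -> list:
    # Bounding the queue to `num_tasks` means at most `num_tasks` arguments are
    # buffered at once; the producer blocks until a worker frees up a slot.
    queue: asyncio.Queue = asyncio.Queue(maxsize=num_tasks)
    results: list = []

    async def producer() -> None:
        for arg in args_iter:
            await queue.put(arg)
        for _ in range(num_tasks):
            await queue.put(_STOP)  # one sentinel per worker

    async def worker() -> None:
        # Each worker awaits one call at a time, so at most `num_tasks`
        # invocations of `func` are ever in flight concurrently.
        while (arg := await queue.get()) is not _STOP:
            results.append(await func(arg))

    await asyncio.gather(producer(), *(worker() for _ in range(num_tasks)))
    return results

Run for example with `asyncio.run(map_concurrently(some_coro_func, range(10), num_tasks=3))`, where `some_coro_func` is any single-argument coroutine function.
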
@@ -619,9 +638,45 @@ class TaskPool(BaseTaskPool):
class SimpleTaskPool(BaseTaskPool):
"""
Simplified task pool class.
A `SimpleTaskPool` instance can manage an arbitrary number of concurrent tasks,
but they **must** come from a single coroutine function, called with the same arguments.
The coroutine function and its arguments are defined upon initialization.
As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.)
Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment.
However, since all tasks come from the same function-and-arguments combination, the specificity of the `cancel()` method
is probably unnecessary. Instead, a simpler `stop()` method is introduced.
Adding tasks blocks **only if** the pool is full at that moment.
"""
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None:
pool_size: int = inf, name: str = None) -> None:
"""
Args:
func:
The function to use for spawning new tasks within the pool.
args (optional):
The positional arguments to pass into each function call.
kwargs (optional):
The keyword-arguments to pass into each function call.
end_callback (optional):
A callback to execute after a task has ended.
It is run with the task's ID as its only positional argument.
cancel_callback (optional):
A callback to execute after cancellation of a task.
It is run with the task's ID as its only positional argument.
pool_size (optional):
The maximum number of tasks allowed to run concurrently in the pool.
name (optional):
An optional name for the pool.
"""
if not iscoroutinefunction(func):
raise exceptions.NotCoroutine(f"Not a coroutine function: {func}")
self._func: CoroutineFunc = func
@@ -629,32 +684,39 @@ class SimpleTaskPool(BaseTaskPool):
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
self._end_callback: EndCallbackT = end_callback
self._cancel_callback: CancelCallbackT = cancel_callback
super().__init__(name=name)
super().__init__(pool_size=pool_size, name=name)
@property
def func_name(self) -> str:
"""Returns the name of the coroutine function used in the pool."""
return self._func.__name__
@property
def size(self) -> int:
return self.num_running
async def _start_one(self) -> int:
"""Starts a single new task within the pool and returns its ID."""
return await self._start_task(self._func(*self._args, **self._kwargs),
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
async def start(self, num: int = 1) -> List[int]:
return [await self._start_one() for _ in range(num)]
"""Starts `num` new tasks within the pool and returns their IDs as a list."""
ids = await gather(*(self._start_one() for _ in range(num)))
assert isinstance(ids, list) # for PyCharm (see above to-do-item)
return ids
def stop(self, num: int = 1) -> List[int]:
num = min(num, self.size)
"""
Cancels `num` running tasks within the pool and returns their IDs as a list.
The tasks are cancelled in LIFO order, meaning tasks started later will be stopped before those started earlier.
If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled.
"""
ids = []
for i, task_id in enumerate(reversed(self._running)):
if i >= num:
break
break  # We got the desired number of task IDs; there may well be more tasks left to keep running
ids.append(task_id)
self.cancel(*ids)
return ids
def stop_all(self) -> List[int]:
return self.stop(self.size)
"""Cancels all running tasks and returns their IDs as a list."""
return self.stop(self.num_running)
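
Taken together, the `SimpleTaskPool` docstring and the `start()`/`stop()`/`stop_all()` methods above suggest a usage pattern along the following lines. This is a rough sketch assuming the class is importable from the package's `pool` module (as in the tests); the `worker` coroutine and the sleep timings are purely illustrative, and a real application would presumably also gather/close the pool, which is omitted here.

import asyncio
from asyncio_taskpool.pool import SimpleTaskPool  # import path as used in the tests

async def worker(delay: float) -> None:
    # Placeholder coroutine; every task in the pool runs this with the same arguments.
    while True:
        await asyncio.sleep(delay)

async def main() -> None:
    taskpool = SimpleTaskPool(worker, args=(0.5,))
    started = await taskpool.start(4)  # spawns four tasks, returns their IDs
    await asyncio.sleep(2)
    stopped = taskpool.stop(2)         # cancels the two most recently started tasks (LIFO)
    print(started, stopped)
    taskpool.stop_all()                # cancels whatever is still running

asyncio.run(main())
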

View File

@@ -63,8 +63,8 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
writer.write(str(self._pool.stop_all()).encode())
def _pool_size(self, writer: StreamWriter) -> None:
log.debug("%s requests pool size", self.client_class.__name__)
writer.write(str(self._pool.size).encode())
log.debug("%s requests number of running tasks", self.client_class.__name__)
writer.write(str(self._pool.num_running).encode())
def _pool_func(self, writer: StreamWriter) -> None:
log.debug("%s requests pool function", self.client_class.__name__)
@@ -83,7 +83,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL:
self._stop_all_tasks(writer)
elif cmd == constants.CMD_SIZE:
elif cmd == constants.CMD_NUM_RUNNING:
self._pool_size(writer)
elif cmd == constants.CMD_FUNC:
self._pool_func(writer)
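
The dispatch above maps the renamed `CMD_NUM_RUNNING` command to a reply containing the pool's `num_running` value. For illustration, a client could query that value over an asyncio stream connection roughly as follows; this is only a sketch, and the host/port, the constants' module path, and the absence of any handshake are assumptions not confirmed by the diff.

import asyncio
from asyncio_taskpool.constants import CMD_NUM_RUNNING, MSG_BYTES  # module path assumed

async def query_num_running(host: str = '127.0.0.1', port: int = 8001) -> int:
    # Open a connection to the control server, send the command string,
    # and read back the (plain-text) number of currently running tasks.
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(CMD_NUM_RUNNING.encode())
    await writer.drain()
    reply = await reader.read(MSG_BYTES)
    writer.close()
    await writer.wait_closed()
    return int(reply.decode())
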

View File

@@ -1,7 +1,9 @@
import asyncio
from asyncio.exceptions import CancelledError
from asyncio.queues import Queue
from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type
from asyncio_taskpool import pool, exceptions
@@ -14,7 +16,12 @@ class TestException(Exception):
pass
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
class CommonTestCase(IsolatedAsyncioTestCase):
TEST_CLASS: Type[pool.BaseTaskPool] = pool.BaseTaskPool
TEST_POOL_SIZE: int = 420
TEST_POOL_NAME: str = 'test123'
task_pool: pool.BaseTaskPool
log_lvl: int
@classmethod
@@ -26,35 +33,38 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
def setUp(self) -> None:
self._pools = getattr(pool.BaseTaskPool, '_pools')
def get_task_pool_init_params(self) -> dict:
return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME}
# These three methods are called during initialization, so we mock them by default during setup
self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
def setUp(self) -> None:
self._pools = self.TEST_CLASS._pools
# These three methods are called during initialization, so we mock them by default during setup:
self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool')
self.pool_size_patcher = patch.object(self.TEST_CLASS, 'pool_size', new_callable=PropertyMock)
self.dunder_str_patcher = patch.object(self.TEST_CLASS, '__str__')
self.mock__add_pool = self._add_pool_patcher.start()
self.mock_pool_size = self.pool_size_patcher.start()
self.mock___str__ = self.__str___patcher.start()
self.mock___str__ = self.dunder_str_patcher.start()
self.mock__add_pool.return_value = self.mock_idx = 123
self.mock___str__.return_value = self.mock_str = 'foobar'
# Test pool parameters:
self.test_pool_size, self.test_pool_name = 420, 'test123'
self.task_pool = pool.BaseTaskPool(pool_size=self.test_pool_size, name=self.test_pool_name)
self.task_pool = self.TEST_CLASS(**self.get_task_pool_init_params())
def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools)
self.TEST_CLASS._pools.clear()
self._add_pool_patcher.stop()
self.pool_size_patcher.stop()
self.__str___patcher.stop()
self.dunder_str_patcher.stop()
class BaseTaskPoolTestCase(CommonTestCase):
def test__add_pool(self):
self.assertListEqual(EMPTY_LIST, self._pools)
self._add_pool_patcher.stop()
output = pool.TaskPool._add_pool(self.task_pool)
output = pool.BaseTaskPool._add_pool(self.task_pool)
self.assertEqual(0, output)
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools)
def test_init(self):
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
@@ -66,26 +76,26 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.assertEqual(0, self.task_pool._num_cancelled)
self.assertEqual(0, self.task_pool._num_ended)
self.assertEqual(self.mock_idx, self.task_pool._idx)
self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE)
self.mock___str__.assert_called_once_with()
def test___str__(self):
self.__str___patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.test_pool_name}'
self.dunder_str_patcher.stop()
expected_str = f'{pool.BaseTaskPool.__name__}-{self.TEST_POOL_NAME}'
self.assertEqual(expected_str, str(self.task_pool))
setattr(self.task_pool, '_name', None)
self.task_pool._name = None
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool))
def test_pool_size(self):
self.pool_size_patcher.stop()
self.task_pool._pool_size = self.test_pool_size
self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
self.task_pool._pool_size = self.TEST_POOL_SIZE
self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size)
with self.assertRaises(ValueError):
self.task_pool.pool_size = -1
@@ -377,3 +387,278 @@ class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
check_assertions(await self.task_pool.gather(return_exceptions=True))
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
@patch.object(pool.TaskPool, '_start_task')
async def test__apply_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 12345
mock_awaitable = MagicMock()
mock_func = MagicMock(return_value=mock_awaitable)
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args, **kwargs)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
mock_func.reset_mock()
mock__start_task.reset_mock()
output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock_func.assert_called_once_with(*args)
mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_apply_one')
async def test_apply(self, mock__apply_one: AsyncMock):
mock__apply_one.return_value = mock_id = 67890
mock_func, num = MagicMock(), 3
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock()
expected_output = num * [mock_id]
output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb)
self.assertEqual(expected_output, output)
mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)])
async def test__queue_producer(self):
mock_put = AsyncMock()
mock_q = MagicMock(put=mock_put)
args = (FOO, BAR, 123)
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_has_awaits([call(arg) for arg in args])
mock_put.reset_mock()
self.task_pool._interrupt_flag.set()
self.assertIsNone(await self.task_pool._queue_producer(mock_q, args))
mock_put.assert_not_awaited()
@patch.object(pool, 'partial')
@patch.object(pool, 'star_function')
@patch.object(pool.TaskPool, '_start_task')
async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock,
mock_partial: MagicMock):
mock_partial.return_value = queue_callback = 'not really'
mock_star_function.return_value = awaitable = 'totally an awaitable'
q, arg = Queue(), 420.69
q.put_nowait(arg)
mock_func, stars = MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_awaited_once_with(awaitable, ignore_closed=True,
end_callback=queue_callback, cancel_callback=cancel_cb)
mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars)
mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool,
q=q, func=mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock__start_task.reset_mock()
mock_star_function.reset_mock()
mock_partial.reset_mock()
self.assertIsNone(await self.task_pool._queue_consumer(q, mock_func, stars, end_cb, cancel_cb))
self.assertTrue(q.empty())
mock__start_task.assert_not_awaited()
mock_star_function.assert_not_called()
mock_partial.assert_not_called()
@patch.object(pool, 'execute_optional')
@patch.object(pool.TaskPool, '_queue_consumer')
async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock):
task_id, mock_q = 420, MagicMock()
mock_func, stars = MagicMock(), 3
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_func, stars, end_cb, cancel_cb))
mock__queue_consumer.assert_awaited_once_with(mock_q, mock_func, stars,
end_callback=end_cb, cancel_callback=cancel_cb)
mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,))
@patch.object(pool, 'iter')
@patch.object(pool, 'create_task')
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock,
mock_create_task: MagicMock, mock_iter: MagicMock):
args, num_tasks = (FOO, BAR, 1, 2, 3), 2
mock_join_queue.return_value = mock_join = 'awaitable'
mock_iter.return_value = args_iter = iter(args)
mock__queue_producer.return_value = mock_producer_coro = 'very awaitable'
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(num_tasks, output_q.qsize())
for arg in args[:num_tasks]:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
for arg in args[num_tasks:]:
self.assertEqual(arg, next(args_iter))
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_called_once_with(output_q, args_iter)
mock_create_task.assert_called_once_with(mock_producer_coro)
self.task_pool._before_gathering.clear()
mock_join_queue.reset_mock()
mock__queue_producer.reset_mock()
mock_create_task.reset_mock()
num_tasks = 6
mock_iter.return_value = args_iter = iter(args)
output_q = self.task_pool._set_up_args_queue(args, num_tasks)
self.assertIsInstance(output_q, Queue)
self.assertEqual(len(args), output_q.qsize())
for arg in args:
self.assertEqual(arg, output_q.get_nowait())
self.assertTrue(output_q.empty())
with self.assertRaises(StopIteration):
next(args_iter)
self.assertListEqual([mock_join], self.task_pool._before_gathering)
mock_join_queue.assert_called_once_with(output_q)
mock__queue_producer.assert_not_called()
mock_create_task.assert_not_called()
@patch.object(pool.TaskPool, '_queue_consumer')
@patch.object(pool.TaskPool, '_set_up_args_queue')
@patch.object(pool.TaskPool, 'is_open', new_callable=PropertyMock)
async def test__map(self, mock_is_open: MagicMock, mock__set_up_args_queue: MagicMock,
mock__queue_consumer: AsyncMock):
qsize = 4
mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize))
mock_func, stars = MagicMock(), 3
args_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
mock_is_open.return_value = False
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb)
mock_is_open.assert_called_once_with()
mock__set_up_args_queue.assert_not_called()
mock__queue_consumer.assert_not_awaited()
mock_is_open.reset_mock()
mock_is_open.return_value = True
self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, num_tasks, end_cb, cancel_cb))
mock__set_up_args_queue.assert_called_once_with(args_iter, num_tasks)
mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_func, arg_stars=stars,
end_callback=end_cb, cancel_callback=cancel_cb)])
@patch.object(pool.TaskPool, '_map')
async def test_map(self, mock__map: AsyncMock):
mock_func = MagicMock()
arg_iter, num_tasks = (FOO, BAR, 1, 2, 3), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_starmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
args_iter, num_tasks = ([FOO], [BAR]), 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
@patch.object(pool.TaskPool, '_map')
async def test_doublestarmap(self, mock__map: AsyncMock):
mock_func = MagicMock()
kwargs_iter, num_tasks = [{'a': FOO}, {'a': BAR}], 2
end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb))
mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_cb, cancel_callback=cancel_cb)
class SimpleTaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.SimpleTaskPool
task_pool: pool.SimpleTaskPool
TEST_POOL_FUNC = AsyncMock(__name__=FOO)
TEST_POOL_ARGS = (FOO, BAR)
TEST_POOL_KWARGS = {'a': 1, 'b': 2}
TEST_POOL_END_CB = MagicMock()
TEST_POOL_CANCEL_CB = MagicMock()
def get_task_pool_init_params(self) -> dict:
return super().get_task_pool_init_params() | {
'func': self.TEST_POOL_FUNC,
'args': self.TEST_POOL_ARGS,
'kwargs': self.TEST_POOL_KWARGS,
'end_callback': self.TEST_POOL_END_CB,
'cancel_callback': self.TEST_POOL_CANCEL_CB,
}
def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
self.base_class_init = self.base_class_init_patcher.start()
super().setUp()
def tearDown(self) -> None:
self.base_class_init_patcher.stop()
def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)
self.assertEqual(self.TEST_POOL_ARGS, self.task_pool._args)
self.assertEqual(self.TEST_POOL_KWARGS, self.task_pool._kwargs)
self.assertEqual(self.TEST_POOL_END_CB, self.task_pool._end_callback)
self.assertEqual(self.TEST_POOL_CANCEL_CB, self.task_pool._cancel_callback)
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
with self.assertRaises(exceptions.NotCoroutine):
pool.SimpleTaskPool(MagicMock())
def test_func_name(self):
self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name)
@patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_one(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 99
self.task_pool._func = MagicMock(return_value=BAR)
output = await self.task_pool._start_one()
self.assertEqual(expected_output, output)
self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs)
mock__start_task.assert_awaited_once_with(BAR, end_callback=self.task_pool._end_callback,
cancel_callback=self.task_pool._cancel_callback)
@patch.object(pool.SimpleTaskPool, '_start_one')
async def test_start(self, mock__start_one: AsyncMock):
mock__start_one.return_value = FOO
num = 5
output = await self.task_pool.start(num)
expected_output = num * [FOO]
self.assertListEqual(expected_output, output)
mock__start_one.assert_has_awaits(num * [call()])
@patch.object(pool.SimpleTaskPool, 'cancel')
def test_stop(self, mock_cancel: MagicMock):
num = 2
id1, id2, id3 = 5, 6, 7
self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR}
output = self.task_pool.stop(num)
expected_output = [id3, id2]
self.assertEqual(expected_output, output)
mock_cancel.assert_called_once_with(*expected_output)
mock_cancel.reset_mock()
num = 50
output = self.task_pool.stop(num)
expected_output = [id3, id2, id1]
self.assertEqual(expected_output, output)
mock_cancel.assert_called_once_with(*expected_output)
@patch.object(pool.SimpleTaskPool, 'num_running', new_callable=PropertyMock)
@patch.object(pool.SimpleTaskPool, 'stop')
def test_stop_all(self, mock_stop: MagicMock, mock_num_running: MagicMock):
mock_num_running.return_value = num = 9876
mock_stop.return_value = expected_output = 'something'
output = self.task_pool.stop_all()
self.assertEqual(expected_output, output)
mock_num_running.assert_called_once_with()
mock_stop.assert_called_once_with(num)