full test coverage for BaseTaskPool

This commit is contained in:
Daniil Fajnberg 2022-02-07 14:23:49 +01:00
parent e8e13406ea
commit ac903d9be7
4 changed files with 274 additions and 22 deletions

2
.gitignore vendored
View File

@ -8,3 +8,5 @@
/dist/ /dist/
# Python cache: # Python cache:
__pycache__/ __pycache__/
# Testing:
.coverage

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.0.1 version = 0.0.2
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks

View File

@ -210,9 +210,9 @@ class BaseTaskPool:
if not (self.is_open or ignore_closed): if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks") raise exceptions.PoolIsClosed("Cannot start new tasks")
await self._enough_room.acquire() await self._enough_room.acquire()
try:
task_id = self._counter task_id = self._counter
self._counter += 1 self._counter += 1
try:
self._running[task_id] = create_task( self._running[task_id] = create_task(
self._task_wrapper(awaitable, task_id, end_callback, cancel_callback), self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
name=self._task_name(task_id) name=self._task_name(task_id)
@ -243,16 +243,6 @@ class BaseTaskPool:
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running") raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
def _cancel_task(self, task_id: int, msg: str = None) -> None:
"""
Cancels the running task with the specified ID.
Args:
task_id: The ID of a task running within the pool that should be cancelled.
msg (optional): Passed to the `Task.cancel()` method.
"""
self._get_running_task(task_id).cancel(msg=msg)
def cancel(self, *task_ids: int, msg: str = None) -> None: def cancel(self, *task_ids: int, msg: str = None) -> None:
""" """
Cancels the tasks with the specified IDs. Cancels the tasks with the specified IDs.
@ -303,6 +293,7 @@ class BaseTaskPool:
""" """
results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions) results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions)
self._ended = self._cancelled = {} self._ended = self._cancelled = {}
if self._interrupt_flag.is_set():
self._interrupt_flag.clear() self._interrupt_flag.clear()
return results return results
@ -338,6 +329,7 @@ class BaseTaskPool:
results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(),
return_exceptions=return_exceptions) return_exceptions=return_exceptions)
self._ended = self._cancelled = self._running = {} self._ended = self._cancelled = self._running = {}
if self._interrupt_flag.is_set():
self._interrupt_flag.clear() self._interrupt_flag.clear()
return results return results

View File

@ -1,14 +1,31 @@
import asyncio import asyncio
from unittest import TestCase from asyncio.exceptions import CancelledError
from unittest.mock import PropertyMock, patch from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from asyncio_taskpool import pool from asyncio_taskpool import pool, exceptions
EMPTY_LIST, EMPTY_DICT = [], {} EMPTY_LIST, EMPTY_DICT = [], {}
FOO, BAR = 'foo', 'bar'
class BaseTaskPoolTestCase(TestCase): class TestException(Exception):
pass
class BaseTaskPoolTestCase(IsolatedAsyncioTestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = pool.log.level
pool.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
pool.log.setLevel(cls.log_lvl)
def setUp(self) -> None: def setUp(self) -> None:
self._pools = getattr(pool.BaseTaskPool, '_pools') self._pools = getattr(pool.BaseTaskPool, '_pools')
@ -78,15 +95,15 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(new_size, self.task_pool._pool_size) self.assertEqual(new_size, self.task_pool._pool_size)
def test_is_open(self): def test_is_open(self):
self.task_pool._open = foo = 'foo' self.task_pool._open = FOO
self.assertEqual(foo, self.task_pool.is_open) self.assertEqual(FOO, self.task_pool.is_open)
def test_num_running(self): def test_num_running(self):
self.task_pool._running = ['foo', 'bar', 'baz'] self.task_pool._running = ['foo', 'bar', 'baz']
self.assertEqual(3, self.task_pool.num_running) self.assertEqual(3, self.task_pool.num_running)
def test_num_cancelled(self): def test_num_cancelled(self):
self.task_pool._num_cancelled = 33 self.task_pool._num_cancelled = 3
self.assertEqual(3, self.task_pool.num_cancelled) self.assertEqual(3, self.task_pool.num_cancelled)
def test_num_ended(self): def test_num_ended(self):
@ -105,3 +122,244 @@ class BaseTaskPoolTestCase(TestCase):
def test__task_name(self): def test__task_name(self):
i = 123 i = 123
self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i))
@patch.object(pool, '_execute_function')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_cancelled = cancelled = 3
self.task_pool._running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._cancelled[task_id])
self.assertEqual(cancelled + 1, self.task_pool._num_cancelled)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool, '_execute_function')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock):
task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock()
self.task_pool._num_ended = ended = 3
self.task_pool._enough_room._value = room = 123
# End running task:
self.task_pool._running[task_id] = mock_task
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._running)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 1, self.task_pool._num_ended)
self.assertEqual(room + 1, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
mock__task_name.reset_mock()
mock__execute_function.reset_mock()
# End cancelled task:
self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id)
self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback))
self.assertNotIn(task_id, self.task_pool._cancelled)
self.assertEqual(mock_task, self.task_pool._ended[task_id])
self.assertEqual(ended + 2, self.task_pool._num_ended)
self.assertEqual(room + 2, self.task_pool._enough_room._value)
mock__task_name.assert_called_with(task_id)
mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, ))
@patch.object(pool.BaseTaskPool, '_task_ending')
@patch.object(pool.BaseTaskPool, '_task_cancellation')
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
async def test__task_wrapper(self, mock__task_name: MagicMock,
mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock):
task_id = 42
mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock()
mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError)
# Cancelled during execution:
mock_awaitable = mock_coroutine_func()
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertIsNone(output)
mock_coroutine_func.assert_awaited_once()
mock__task_name.assert_called_with(task_id)
mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb)
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
mock_coroutine_func.reset_mock(side_effect=True)
mock__task_name.reset_mock()
mock__task_cancellation.reset_mock()
mock__task_ending.reset_mock()
# Not cancelled:
mock_awaitable = mock_coroutine_func()
output = await self.task_pool._task_wrapper(mock_awaitable, task_id,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(FOO, output)
mock_coroutine_func.assert_awaited_once()
mock__task_name.assert_called_with(task_id)
mock__task_cancellation.assert_not_awaited()
mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb)
@patch.object(pool, 'create_task')
@patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
@patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock)
async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock,
mock__task_wrapper: AsyncMock, mock_create_task: MagicMock):
def reset_mocks() -> None:
mock_is_open.reset_mock()
mock__task_name.reset_mock()
mock__task_wrapper.reset_mock()
mock_create_task.reset_mock()
mock_create_task.return_value = mock_task = MagicMock()
mock__task_wrapper.return_value = mock_wrapped = MagicMock()
mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock()
self.task_pool._counter = count = 123
self.task_pool._enough_room._value = room = 123
mock_is_open.return_value = ignore_closed = False
with self.assertRaises(exceptions.PoolIsClosed):
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_not_called()
mock__task_wrapper.assert_not_called()
mock_create_task.assert_not_called()
reset_mocks()
ignore_closed = True
output = await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(count, output)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertEqual(mock_task, self.task_pool._running[count])
self.assertEqual(room - 1, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
reset_mocks()
self.task_pool._counter = count
self.task_pool._enough_room._value = room
del self.task_pool._running[count]
mock_create_task.side_effect = test_exception = TestException()
with self.assertRaises(TestException) as e:
await self.task_pool._start_task(mock_awaitable, ignore_closed,
end_callback=mock_end_cb, cancel_callback=mock_cancel_cb)
self.assertEqual(test_exception, e)
self.assertEqual(count + 1, self.task_pool._counter)
self.assertNotIn(count, self.task_pool._running)
self.assertEqual(room, self.task_pool._enough_room._value)
mock_is_open.assert_called_once_with()
mock__task_name.assert_called_once_with(count)
mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb)
mock_create_task.assert_called_once_with(mock_wrapped, name=FOO)
@patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO)
def test__get_running_task(self, mock__task_name: MagicMock):
task_id, mock_task = 555, MagicMock()
self.task_pool._running[task_id] = mock_task
output = self.task_pool._get_running_task(task_id)
self.assertEqual(mock_task, output)
self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id)
with self.assertRaises(exceptions.AlreadyCancelled):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id)
with self.assertRaises(exceptions.TaskEnded):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_called_once_with(task_id)
mock__task_name.reset_mock()
del self.task_pool._ended[task_id]
with self.assertRaises(exceptions.InvalidTaskID):
self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock):
task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
def test_cancel_all(self):
mock_task1, mock_task2 = MagicMock(), MagicMock()
self.task_pool._running = {1: mock_task1, 2: mock_task2}
assert not self.task_pool._interrupt_flag.is_set()
self.assertIsNone(self.task_pool.cancel_all(FOO))
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_task1.cancel.assert_called_once_with(msg=FOO)
mock_task2.cancel.assert_called_once_with(msg=FOO)
async def test_flush(self):
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._interrupt_flag.set()
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
output = await self.task_pool.flush(return_exceptions=True)
self.assertListEqual([FOO, test_exception], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
def test_close(self):
assert self.task_pool._open
self.task_pool.close()
self.assertFalse(self.task_pool._open)
async def test_gather(self):
mock_wait = AsyncMock()
self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait)
test_exception = TestException()
mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception)
mock_running_func = AsyncMock(return_value=BAR)
self.task_pool._ended = ended = {123: mock_ended_func()}
self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._running = running = {789: mock_running_func()}
self.task_pool._interrupt_flag.set()
assert self.task_pool._open
with self.assertRaises(exceptions.PoolStillOpen):
await self.task_pool.gather()
self.assertDictEqual(self.task_pool._ended, ended)
self.assertDictEqual(self.task_pool._cancelled, cancelled)
self.assertDictEqual(self.task_pool._running, running)
self.assertTrue(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_not_awaited()
self.task_pool._open = False
def check_assertions() -> None:
self.assertListEqual([FOO, test_exception, BAR], output)
self.assertDictEqual(self.task_pool._ended, EMPTY_DICT)
self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT)
self.assertDictEqual(self.task_pool._running, EMPTY_DICT)
self.assertFalse(self.task_pool._interrupt_flag.is_set())
mock_wait.assert_awaited_once_with()
output = await self.task_pool.gather(return_exceptions=True)
check_assertions()
mock_wait.reset_mock()
self.task_pool._ended = {123: mock_ended_func()}
self.task_pool._cancelled = {456: mock_cancelled_func()}
self.task_pool._running = {789: mock_running_func()}
output = await self.task_pool.gather(return_exceptions=True)
check_assertions()