fixed potential race cond. gathering meta tasks

This commit is contained in:
2022-03-25 12:58:18 +01:00
parent 7e34aa106d
commit a9011076c4
4 changed files with 34 additions and 47 deletions

View File

@ -81,12 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(ValueError):
helpers.star_function(f, a, 123456789)
async def test_join_queue(self):
mock_join = AsyncMock()
mock_queue = MagicMock(join=mock_join)
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()
def test_get_first_doc_line(self):
expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output}

View File

@ -93,7 +93,6 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
@ -366,9 +365,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
async def test_gather_and_close(self):
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
mock_running_func = AsyncMock()
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._tasks_running = running = {789: mock_running_func()}
@ -378,19 +376,16 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertDictEqual(ended, self.task_pool._tasks_ended)
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
self.assertDictEqual(running, self.task_pool._tasks_running)
self.assertListEqual(before_gather, self.task_pool._before_gathering)
self.assertFalse(self.task_pool._closed)
self.task_pool._locked = True
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_before_gather.assert_awaited_once_with()
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
mock_running_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
self.assertTrue(self.task_pool._closed)
@ -623,19 +618,32 @@ class TaskPoolTestCase(CommonTestCase):
call(mock_func, bad, arg_stars=stars)
])
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock()
# With a CancelledError thrown while starting a task:
mock_semaphore_cls.return_value = semaphore = Semaphore(1)
mock_star_function.side_effect = CancelledError()
mock_q = MagicMock(__aenter__=AsyncMock(return_value=arg1), __aexit__=AsyncMock(), maxsize=mock_q_maxsize)
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
self.assertFalse(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_not_called()
mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars)
@patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool, 'Queue')
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start')
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
mock_create_task: MagicMock):
mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_queue_cls.return_value = mock_q = MagicMock()
mock_join_queue.return_value = fake_join = object()
mock__queue_producer.return_value = fake_producer = object()
mock__queue_consumer.return_value = fake_consumer = object()
fake_task1, fake_task2 = object(), object()
@ -669,8 +677,6 @@ class TaskPoolTestCase(CommonTestCase):
self.task_pool._task_groups[group_name] = mock_group_reg
mock_group_reg.__aenter__.assert_awaited_once_with()
mock_queue_cls.assert_called_once_with(maxsize=group_size)
mock_join_queue.assert_called_once_with(mock_q)
self.assertListEqual([fake_join], self.task_pool._before_gathering)
mock__queue_producer.assert_called_once()
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])