diff --git a/.gitignore b/.gitignore index 22ce824..92c90e5 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ /dist/ # Python cache: __pycache__/ +# Testing: +.coverage diff --git a/setup.cfg b/setup.cfg index f0a9d25..97fce9f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.0.1 +version = 0.0.2 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 0b2ec46..85d6c17 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -210,9 +210,9 @@ class BaseTaskPool: if not (self.is_open or ignore_closed): raise exceptions.PoolIsClosed("Cannot start new tasks") await self._enough_room.acquire() + task_id = self._counter + self._counter += 1 try: - task_id = self._counter - self._counter += 1 self._running[task_id] = create_task( self._task_wrapper(awaitable, task_id, end_callback, cancel_callback), name=self._task_name(task_id) @@ -243,16 +243,6 @@ class BaseTaskPool: raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running") raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}") - def _cancel_task(self, task_id: int, msg: str = None) -> None: - """ - Cancels the running task with the specified ID. - - Args: - task_id: The ID of a task running within the pool that should be cancelled. - msg (optional): Passed to the `Task.cancel()` method. - """ - self._get_running_task(task_id).cancel(msg=msg) - def cancel(self, *task_ids: int, msg: str = None) -> None: """ Cancels the tasks with the specified IDs. @@ -303,7 +293,8 @@ class BaseTaskPool: """ results = await gather(*self._ended.values(), *self._cancelled.values(), return_exceptions=return_exceptions) self._ended = self._cancelled = {} - self._interrupt_flag.clear() + if self._interrupt_flag.is_set(): + self._interrupt_flag.clear() return results def close(self) -> None: @@ -338,7 +329,8 @@ class BaseTaskPool: results = await gather(*self._ended.values(), *self._cancelled.values(), *self._running.values(), return_exceptions=return_exceptions) self._ended = self._cancelled = self._running = {} - self._interrupt_flag.clear() + if self._interrupt_flag.is_set(): + self._interrupt_flag.clear() return results diff --git a/tests/test_pool.py b/tests/test_pool.py index b2340b2..bf58573 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,14 +1,31 @@ import asyncio -from unittest import TestCase -from unittest.mock import PropertyMock, patch +from asyncio.exceptions import CancelledError +from unittest import IsolatedAsyncioTestCase +from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call -from asyncio_taskpool import pool +from asyncio_taskpool import pool, exceptions EMPTY_LIST, EMPTY_DICT = [], {} +FOO, BAR = 'foo', 'bar' -class BaseTaskPoolTestCase(TestCase): +class TestException(Exception): + pass + + +class BaseTaskPoolTestCase(IsolatedAsyncioTestCase): + log_lvl: int + + @classmethod + def setUpClass(cls) -> None: + cls.log_lvl = pool.log.level + pool.log.setLevel(999) + + @classmethod + def tearDownClass(cls) -> None: + pool.log.setLevel(cls.log_lvl) + def setUp(self) -> None: self._pools = getattr(pool.BaseTaskPool, '_pools') @@ -78,15 +95,15 @@ class BaseTaskPoolTestCase(TestCase): self.assertEqual(new_size, self.task_pool._pool_size) def test_is_open(self): - self.task_pool._open = foo = 'foo' - self.assertEqual(foo, self.task_pool.is_open) + self.task_pool._open = FOO + self.assertEqual(FOO, self.task_pool.is_open) def test_num_running(self): self.task_pool._running = ['foo', 'bar', 'baz'] self.assertEqual(3, self.task_pool.num_running) def test_num_cancelled(self): - self.task_pool._num_cancelled = 33 + self.task_pool._num_cancelled = 3 self.assertEqual(3, self.task_pool.num_cancelled) def test_num_ended(self): @@ -105,3 +122,244 @@ class BaseTaskPoolTestCase(TestCase): def test__task_name(self): i = 123 self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) + + @patch.object(pool, '_execute_function') + @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) + async def test__task_cancellation(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock): + task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() + self.task_pool._num_cancelled = cancelled = 3 + self.task_pool._running[task_id] = mock_task + self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback)) + self.assertNotIn(task_id, self.task_pool._running) + self.assertEqual(mock_task, self.task_pool._cancelled[task_id]) + self.assertEqual(cancelled + 1, self.task_pool._num_cancelled) + mock__task_name.assert_called_with(task_id) + mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) + + @patch.object(pool, '_execute_function') + @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) + async def test__task_ending(self, mock__task_name: MagicMock, mock__execute_function: AsyncMock): + task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() + self.task_pool._num_ended = ended = 3 + self.task_pool._enough_room._value = room = 123 + + # End running task: + self.task_pool._running[task_id] = mock_task + self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) + self.assertNotIn(task_id, self.task_pool._running) + self.assertEqual(mock_task, self.task_pool._ended[task_id]) + self.assertEqual(ended + 1, self.task_pool._num_ended) + self.assertEqual(room + 1, self.task_pool._enough_room._value) + mock__task_name.assert_called_with(task_id) + mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) + mock__task_name.reset_mock() + mock__execute_function.reset_mock() + + # End cancelled task: + self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id) + self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) + self.assertNotIn(task_id, self.task_pool._cancelled) + self.assertEqual(mock_task, self.task_pool._ended[task_id]) + self.assertEqual(ended + 2, self.task_pool._num_ended) + self.assertEqual(room + 2, self.task_pool._enough_room._value) + mock__task_name.assert_called_with(task_id) + mock__execute_function.assert_awaited_once_with(mock_callback, args=(task_id, )) + + @patch.object(pool.BaseTaskPool, '_task_ending') + @patch.object(pool.BaseTaskPool, '_task_cancellation') + @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) + async def test__task_wrapper(self, mock__task_name: MagicMock, + mock__task_cancellation: AsyncMock, mock__task_ending: AsyncMock): + task_id = 42 + mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock() + mock_coroutine_func = AsyncMock(return_value=FOO, side_effect=CancelledError) + + # Cancelled during execution: + mock_awaitable = mock_coroutine_func() + output = await self.task_pool._task_wrapper(mock_awaitable, task_id, + end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + self.assertIsNone(output) + mock_coroutine_func.assert_awaited_once() + mock__task_name.assert_called_with(task_id) + mock__task_cancellation.assert_awaited_once_with(task_id, custom_callback=mock_cancel_cb) + mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb) + + mock_coroutine_func.reset_mock(side_effect=True) + mock__task_name.reset_mock() + mock__task_cancellation.reset_mock() + mock__task_ending.reset_mock() + + # Not cancelled: + mock_awaitable = mock_coroutine_func() + output = await self.task_pool._task_wrapper(mock_awaitable, task_id, + end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + self.assertEqual(FOO, output) + mock_coroutine_func.assert_awaited_once() + mock__task_name.assert_called_with(task_id) + mock__task_cancellation.assert_not_awaited() + mock__task_ending.assert_awaited_once_with(task_id, custom_callback=mock_end_cb) + + @patch.object(pool, 'create_task') + @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock) + @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) + @patch.object(pool.BaseTaskPool, 'is_open', new_callable=PropertyMock) + async def test__start_task(self, mock_is_open: MagicMock, mock__task_name: MagicMock, + mock__task_wrapper: AsyncMock, mock_create_task: MagicMock): + def reset_mocks() -> None: + mock_is_open.reset_mock() + mock__task_name.reset_mock() + mock__task_wrapper.reset_mock() + mock_create_task.reset_mock() + + mock_create_task.return_value = mock_task = MagicMock() + mock__task_wrapper.return_value = mock_wrapped = MagicMock() + mock_awaitable, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock() + self.task_pool._counter = count = 123 + self.task_pool._enough_room._value = room = 123 + + mock_is_open.return_value = ignore_closed = False + with self.assertRaises(exceptions.PoolIsClosed): + await self.task_pool._start_task(mock_awaitable, ignore_closed, + end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + self.assertEqual(count, self.task_pool._counter) + self.assertNotIn(count, self.task_pool._running) + self.assertEqual(room, self.task_pool._enough_room._value) + mock_is_open.assert_called_once_with() + mock__task_name.assert_not_called() + mock__task_wrapper.assert_not_called() + mock_create_task.assert_not_called() + reset_mocks() + + ignore_closed = True + output = await self.task_pool._start_task(mock_awaitable, ignore_closed, + end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + self.assertEqual(count, output) + self.assertEqual(count + 1, self.task_pool._counter) + self.assertEqual(mock_task, self.task_pool._running[count]) + self.assertEqual(room - 1, self.task_pool._enough_room._value) + mock_is_open.assert_called_once_with() + mock__task_name.assert_called_once_with(count) + mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) + mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) + reset_mocks() + self.task_pool._counter = count + self.task_pool._enough_room._value = room + del self.task_pool._running[count] + + mock_create_task.side_effect = test_exception = TestException() + with self.assertRaises(TestException) as e: + await self.task_pool._start_task(mock_awaitable, ignore_closed, + end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) + self.assertEqual(test_exception, e) + self.assertEqual(count + 1, self.task_pool._counter) + self.assertNotIn(count, self.task_pool._running) + self.assertEqual(room, self.task_pool._enough_room._value) + mock_is_open.assert_called_once_with() + mock__task_name.assert_called_once_with(count) + mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) + mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) + + @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) + def test__get_running_task(self, mock__task_name: MagicMock): + task_id, mock_task = 555, MagicMock() + self.task_pool._running[task_id] = mock_task + output = self.task_pool._get_running_task(task_id) + self.assertEqual(mock_task, output) + + self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id) + with self.assertRaises(exceptions.AlreadyCancelled): + self.task_pool._get_running_task(task_id) + mock__task_name.assert_called_once_with(task_id) + mock__task_name.reset_mock() + + self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id) + with self.assertRaises(exceptions.TaskEnded): + self.task_pool._get_running_task(task_id) + mock__task_name.assert_called_once_with(task_id) + mock__task_name.reset_mock() + + del self.task_pool._ended[task_id] + with self.assertRaises(exceptions.InvalidTaskID): + self.task_pool._get_running_task(task_id) + mock__task_name.assert_not_called() + + @patch.object(pool.BaseTaskPool, '_get_running_task') + def test_cancel(self, mock__get_running_task: MagicMock): + task_id1, task_id2, task_id3 = 1, 4, 9 + mock__get_running_task.return_value.cancel = mock_cancel = MagicMock() + self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO)) + mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) + mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) + + def test_cancel_all(self): + mock_task1, mock_task2 = MagicMock(), MagicMock() + self.task_pool._running = {1: mock_task1, 2: mock_task2} + assert not self.task_pool._interrupt_flag.is_set() + self.assertIsNone(self.task_pool.cancel_all(FOO)) + self.assertTrue(self.task_pool._interrupt_flag.is_set()) + mock_task1.cancel.assert_called_once_with(msg=FOO) + mock_task2.cancel.assert_called_once_with(msg=FOO) + + async def test_flush(self): + test_exception = TestException() + mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) + self.task_pool._ended = {123: mock_ended_func()} + self.task_pool._cancelled = {456: mock_cancelled_func()} + self.task_pool._interrupt_flag.set() + output = await self.task_pool.flush(return_exceptions=True) + self.assertListEqual([FOO, test_exception], output) + self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) + self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) + self.assertFalse(self.task_pool._interrupt_flag.is_set()) + + self.task_pool._ended = {123: mock_ended_func()} + self.task_pool._cancelled = {456: mock_cancelled_func()} + output = await self.task_pool.flush(return_exceptions=True) + self.assertListEqual([FOO, test_exception], output) + self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) + self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) + + def test_close(self): + assert self.task_pool._open + self.task_pool.close() + self.assertFalse(self.task_pool._open) + + async def test_gather(self): + mock_wait = AsyncMock() + self.task_pool._all_tasks_known_flag = MagicMock(wait=mock_wait) + test_exception = TestException() + mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) + mock_running_func = AsyncMock(return_value=BAR) + self.task_pool._ended = ended = {123: mock_ended_func()} + self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()} + self.task_pool._running = running = {789: mock_running_func()} + self.task_pool._interrupt_flag.set() + + assert self.task_pool._open + with self.assertRaises(exceptions.PoolStillOpen): + await self.task_pool.gather() + self.assertDictEqual(self.task_pool._ended, ended) + self.assertDictEqual(self.task_pool._cancelled, cancelled) + self.assertDictEqual(self.task_pool._running, running) + self.assertTrue(self.task_pool._interrupt_flag.is_set()) + mock_wait.assert_not_awaited() + + self.task_pool._open = False + + def check_assertions() -> None: + self.assertListEqual([FOO, test_exception, BAR], output) + self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) + self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) + self.assertDictEqual(self.task_pool._running, EMPTY_DICT) + self.assertFalse(self.task_pool._interrupt_flag.is_set()) + mock_wait.assert_awaited_once_with() + + output = await self.task_pool.gather(return_exceptions=True) + check_assertions() + mock_wait.reset_mock() + + self.task_pool._ended = {123: mock_ended_func()} + self.task_pool._cancelled = {456: mock_cancelled_func()} + self.task_pool._running = {789: mock_running_func()} + output = await self.task_pool.gather(return_exceptions=True) + check_assertions()