diff --git a/src/asyncio_taskpool/helpers.py b/src/asyncio_taskpool/helpers.py index 4bc2ffb..912c975 100644 --- a/src/asyncio_taskpool/helpers.py +++ b/src/asyncio_taskpool/helpers.py @@ -19,7 +19,6 @@ Miscellaneous helper functions. """ -import re from asyncio.coroutines import iscoroutinefunction from asyncio.queues import Queue from inspect import getdoc @@ -57,7 +56,7 @@ def tasks_str(num: int) -> str: def get_first_doc_line(obj: object) -> str: - return getdoc(obj).strip().split("\n", 1)[0] + return getdoc(obj).strip().split("\n", 1)[0].strip() async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]: diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 9c3e551..ee41d1e 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -94,3 +94,35 @@ class HelpersTestCase(IsolatedAsyncioTestCase): self.assertEqual("tasks", helpers.tasks_str(2)) self.assertEqual("tasks", helpers.tasks_str(-10)) self.assertEqual("tasks", helpers.tasks_str(42)) + + def test_get_first_doc_line(self): + expected_output = 'foo bar baz' + mock_obj = MagicMock(__doc__=f"""{expected_output} + something else + + even more + """) + output = helpers.get_first_doc_line(mock_obj) + self.assertEqual(expected_output, output) + + async def test_return_or_exception(self): + expected_output = '420' + mock_func = AsyncMock(return_value=expected_output) + args = (1, 3, 5) + kwargs = {'a': 1, 'b': 2, 'c': 'foo'} + output = await helpers.return_or_exception(mock_func, *args, **kwargs) + self.assertEqual(expected_output, output) + mock_func.assert_awaited_once_with(*args, **kwargs) + + mock_func = MagicMock(return_value=expected_output) + output = await helpers.return_or_exception(mock_func, *args, **kwargs) + self.assertEqual(expected_output, output) + mock_func.assert_called_once_with(*args, **kwargs) + + class TestException(Exception): + pass + test_exception = TestException() + mock_func = MagicMock(side_effect=test_exception) + output = await helpers.return_or_exception(mock_func, *args, **kwargs) + self.assertEqual(test_exception, output) + mock_func.assert_called_once_with(*args, **kwargs)