diff --git a/setup.cfg b/setup.cfg index c75772a..692450f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = asyncio-taskpool -version = 0.1.5 +version = 0.1.6 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Dynamically manage pools of asyncio tasks diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index 8713894..6716797 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -3,6 +3,6 @@ MSG_BYTES = 1024 CMD_START = 'start' CMD_STOP = 'stop' CMD_STOP_ALL = 'stop_all' -CMD_SIZE = 'size' +CMD_NUM_RUNNING = 'num_running' CMD_FUNC = 'func' CLIENT_EXIT = 'exit' diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index a89876c..d2ffe64 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -638,9 +638,45 @@ class TaskPool(BaseTaskPool): class SimpleTaskPool(BaseTaskPool): + """ + Simplified task pool class. + + A `SimpleTaskPool` instance can manage an arbitrary number of concurrent tasks, + but they **must** come from a single coroutine function, called with the same arguments. + + The coroutine function and its arguments are defined upon initialization. + + As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.) + Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment. + However, since all tasks come from the same function-arguments-combination, the specificity of the `cancel()` method + is probably unnecessary. Instead, a simpler `stop()` method is introduced. + + Adding tasks blocks **only if** the pool is full at that moment. + """ + def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None, - name: str = None) -> None: + pool_size: int = inf, name: str = None) -> None: + """ + + Args: + func: + The function to use for spawning new tasks within the pool. + args (optional): + The positional arguments to pass into each function call. + kwargs (optional): + The keyword-arguments to pass into each function call. + end_callback (optional): + A callback to execute after a task has ended. + It is run with the task's ID as its only positional argument. + cancel_callback (optional): + A callback to execute after cancellation of a task. + It is run with the task's ID as its only positional argument. + pool_size (optional): + The maximum number of tasks allowed to run concurrently in the pool + name (optional): + An optional name for the pool. + """ if not iscoroutinefunction(func): raise exceptions.NotCoroutine(f"Not a coroutine function: {func}") self._func: CoroutineFunc = func @@ -648,32 +684,39 @@ class SimpleTaskPool(BaseTaskPool): self._kwargs: KwArgsT = kwargs if kwargs is not None else {} self._end_callback: EndCallbackT = end_callback self._cancel_callback: CancelCallbackT = cancel_callback - super().__init__(name=name) + super().__init__(pool_size=pool_size, name=name) @property def func_name(self) -> str: + """Returns the name of the coroutine function used in the pool.""" return self._func.__name__ - @property - def size(self) -> int: - return self.num_running - async def _start_one(self) -> int: + """Starts a single new task within the pool and returns its ID.""" return await self._start_task(self._func(*self._args, **self._kwargs), end_callback=self._end_callback, cancel_callback=self._cancel_callback) async def start(self, num: int = 1) -> List[int]: - return [await self._start_one() for _ in range(num)] + """Starts `num` new tasks within the pool and returns their IDs as a list.""" + ids = await gather(*(self._start_one() for _ in range(num))) + assert isinstance(ids, list) # for PyCharm (see above to-do-item) + return ids def stop(self, num: int = 1) -> List[int]: - num = min(num, self.size) + """ + Cancels `num` running tasks within the pool and returns their IDs as a list. + + The tasks are canceled in LIFO order, meaning tasks started later will be stopped before those started earlier. + If `num` is greater than or equal to the number of currently running tasks, naturally all tasks are cancelled. + """ ids = [] for i, task_id in enumerate(reversed(self._running)): if i >= num: - break + break # We got the desired number of task IDs, there may well be more tasks left to keep running ids.append(task_id) self.cancel(*ids) return ids def stop_all(self) -> List[int]: - return self.stop(self.size) + """Cancels all running tasks and returns their IDs as a list.""" + return self.stop(self.num_running) diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index 8db8e98..15ff01a 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -63,8 +63,8 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta writer.write(str(self._pool.stop_all()).encode()) def _pool_size(self, writer: StreamWriter) -> None: - log.debug("%s requests pool size", self.client_class.__name__) - writer.write(str(self._pool.size).encode()) + log.debug("%s requests number of running tasks", self.client_class.__name__) + writer.write(str(self._pool.num_running).encode()) def _pool_func(self, writer: StreamWriter) -> None: log.debug("%s requests pool function", self.client_class.__name__) @@ -83,7 +83,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta self._stop_tasks(writer, arg) elif cmd == constants.CMD_STOP_ALL: self._stop_all_tasks(writer) - elif cmd == constants.CMD_SIZE: + elif cmd == constants.CMD_NUM_RUNNING: self._pool_size(writer) elif cmd == constants.CMD_FUNC: self._pool_func(writer) diff --git a/tests/test_pool.py b/tests/test_pool.py index 6e7e3fe..ac03364 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -574,3 +574,91 @@ class TaskPoolTestCase(CommonTestCase): self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, num_tasks, end_cb, cancel_cb)) mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, num_tasks=num_tasks, end_callback=end_cb, cancel_callback=cancel_cb) + + +class SimpleTaskPoolTestCase(CommonTestCase): + TEST_CLASS = pool.SimpleTaskPool + task_pool: pool.SimpleTaskPool + + TEST_POOL_FUNC = AsyncMock(__name__=FOO) + TEST_POOL_ARGS = (FOO, BAR) + TEST_POOL_KWARGS = {'a': 1, 'b': 2} + TEST_POOL_END_CB = MagicMock() + TEST_POOL_CANCEL_CB = MagicMock() + + def get_task_pool_init_params(self) -> dict: + return super().get_task_pool_init_params() | { + 'func': self.TEST_POOL_FUNC, + 'args': self.TEST_POOL_ARGS, + 'kwargs': self.TEST_POOL_KWARGS, + 'end_callback': self.TEST_POOL_END_CB, + 'cancel_callback': self.TEST_POOL_CANCEL_CB, + } + + def setUp(self) -> None: + self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__') + self.base_class_init = self.base_class_init_patcher.start() + super().setUp() + + def tearDown(self) -> None: + self.base_class_init_patcher.stop() + + def test_init(self): + self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func) + self.assertEqual(self.TEST_POOL_ARGS, self.task_pool._args) + self.assertEqual(self.TEST_POOL_KWARGS, self.task_pool._kwargs) + self.assertEqual(self.TEST_POOL_END_CB, self.task_pool._end_callback) + self.assertEqual(self.TEST_POOL_CANCEL_CB, self.task_pool._cancel_callback) + self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME) + + with self.assertRaises(exceptions.NotCoroutine): + pool.SimpleTaskPool(MagicMock()) + + def test_func_name(self): + self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name) + + @patch.object(pool.SimpleTaskPool, '_start_task') + async def test__start_one(self, mock__start_task: AsyncMock): + mock__start_task.return_value = expected_output = 99 + self.task_pool._func = MagicMock(return_value=BAR) + output = await self.task_pool._start_one() + self.assertEqual(expected_output, output) + self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs) + mock__start_task.assert_awaited_once_with(BAR, end_callback=self.task_pool._end_callback, + cancel_callback=self.task_pool._cancel_callback) + + @patch.object(pool.SimpleTaskPool, '_start_one') + async def test_start(self, mock__start_one: AsyncMock): + mock__start_one.return_value = FOO + num = 5 + output = await self.task_pool.start(num) + expected_output = num * [FOO] + self.assertListEqual(expected_output, output) + mock__start_one.assert_has_awaits(num * [call()]) + + @patch.object(pool.SimpleTaskPool, 'cancel') + def test_stop(self, mock_cancel: MagicMock): + num = 2 + id1, id2, id3 = 5, 6, 7 + self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR} + output = self.task_pool.stop(num) + expected_output = [id3, id2] + self.assertEqual(expected_output, output) + mock_cancel.assert_called_once_with(*expected_output) + mock_cancel.reset_mock() + + num = 50 + output = self.task_pool.stop(num) + expected_output = [id3, id2, id1] + self.assertEqual(expected_output, output) + mock_cancel.assert_called_once_with(*expected_output) + + @patch.object(pool.SimpleTaskPool, 'num_running', new_callable=PropertyMock) + @patch.object(pool.SimpleTaskPool, 'stop') + def test_stop_all(self, mock_stop: MagicMock, mock_num_running: MagicMock): + mock_num_running.return_value = num = 9876 + mock_stop.return_value = expected_output = 'something' + output = self.task_pool.stop_all() + self.assertEqual(expected_output, output) + mock_num_running.assert_called_once_with() + mock_stop.assert_called_once_with(num)