3 Commits

13 changed files with 133 additions and 72 deletions


@ -2,9 +2,18 @@
**Dynamically manage pools of asyncio tasks**
## Contents
- [Contents](#contents)
- [Summary](#summary)
- [Usage](#usage)
- [Installation](#installation)
- [Dependencies](#dependencies)
- [Testing](#testing)
- [License](#license)
## Summary
A task pool is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
A **task pool** is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
With an interface that is intentionally similar to the [`multiprocessing.Pool`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool) class from the standard library, the `TaskPool` provides methods such as `apply`, `map`, and `starmap` for executing coroutines concurrently as [`asyncio.Task`](https://docs.python.org/3/library/asyncio-task.html#task-object) objects. There is no limitation imposed on what kind of tasks can be run or in what combination, when new ones can be added, or when they can be cancelled.
@ -22,7 +31,7 @@ from asyncio_taskpool import SimpleTaskPool
...
async def work(foo, bar): ...
async def work(_foo, _bar): ...
...
@ -55,7 +64,7 @@ Python Version 3.8+, tested on Linux
## Testing
Install `asyncio-taskpool[dev]` dependencies or just manually install `coverage` with `pip`.
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
## License
@ -64,6 +73,6 @@ Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests an
The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COPYING.LESSER) are included in this repository. If not, see https://www.gnu.org/licenses/.
## Copyright
---
© 2022 Daniil Fajnberg


@ -1,6 +1,6 @@
[metadata]
name = asyncio-taskpool
version = 0.6.2
version = 0.6.4
author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks


@ -41,7 +41,7 @@ class ControlClient(ABC):
"""
@staticmethod
def client_info() -> dict:
def _client_info() -> dict:
"""Returns a dictionary of client information relevant for the handshake with the server."""
return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}
@ -73,9 +73,10 @@ class ControlClient(ABC):
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
"""
self._connected = True
writer.write(json.dumps(self.client_info()).encode())
writer.write(json.dumps(self._client_info()).encode())
await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
print("Type '-h' to get help and usage instructions for all available commands.\n")
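The client side of the handshake in this hunk boils down to serializing a small dictionary as JSON and writing the encoded bytes to the stream. A minimal sketch of that round trip using only the standard library follows. The key name `terminal_width` and both helper functions are hypothetical stand-ins; the real key is `CLIENT_INFO.TERMINAL_WIDTH`, whose value is not shown in this diff, and the decoding step is only an assumption about what a server might do with the payload.

```python
import json
import shutil

# Hypothetical stand-in for CLIENT_INFO.TERMINAL_WIDTH; the real value is not shown in the diff.
TERMINAL_WIDTH_KEY = 'terminal_width'


def encode_client_info() -> bytes:
    """Builds the payload the way `json.dumps(self._client_info()).encode()` does in the hunk."""
    info = {TERMINAL_WIDTH_KEY: shutil.get_terminal_size().columns}
    return json.dumps(info).encode()


def decode_client_info(payload: bytes) -> dict:
    """Assumed server-side counterpart: turns the received bytes back into a dictionary."""
    return json.loads(payload.decode())


if __name__ == '__main__':
    print(decode_client_info(encode_client_info()))
```

Sending the width explicitly matters because the server formats help messages for a terminal it cannot see.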
def _get_command(self, writer: StreamWriter) -> Optional[str]:
"""


@ -23,10 +23,10 @@ from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, Help
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature
from shutil import get_terminal_size
from typing import Callable, Container, Dict, Set, Type, TypeVar
from typing import Any, Callable, Container, Dict, Set, Type, TypeVar
from ..constants import CLIENT_INFO, CMD, STREAM_WRITER
from ..exceptions import HelpRequested
from ..exceptions import HelpRequested, ParserError
from ..helpers import get_first_doc_line
@ -35,7 +35,6 @@ ParsersDict = Dict[str, 'ControlParser']
OMIT_PARAMS_DEFAULT = ('self', )
FORMATTER_CLASS = 'formatter_class'
NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description'
@ -79,24 +78,23 @@ class ControlParser(ArgumentParser):
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None,
**kwargs) -> None:
"""
Sets additional internal attributes depending on whether a parent-parser was defined.
Subclass of the `ArgumentParser` geared towards asynchronous interaction with an object "from the outside".
The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword.
By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
Allows directing output to a specified writer rather than stdout/stderr and setting terminal width explicitly.
Args:
stream_writer:
The instance of the `asyncio.StreamWriter` to use for message output.
terminal_width (optional):
The terminal width to assume for all message formatting. Defaults to `shutil.get_terminal_size`.
The terminal width to use for all message formatting. Defaults to `shutil.get_terminal_size().columns`.
**kwargs (optional):
In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of
the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument
is empty.
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
class is specified, it will always be subclassed in the `help_formatter_factory`.
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
"""
self._stream_writer: StreamWriter = stream_writer
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS))
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs)
self._flags: Set[str] = set()
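The docstring above says that any `formatter_class` passed in is always subclassed in `help_formatter_factory`, but the factory itself is not part of this diff. The following is a minimal sketch of how such a factory could work, assuming it only needs to pin the output width; the class name `FixedWidthHelpFormatter` and the helper `make_parser` are hypothetical. The `exit_on_error` default is left out to keep the sketch focused on the formatter.

```python
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, HelpFormatter
from typing import Optional, Type


def help_formatter_factory(terminal_width: int,
                           base_cls: Optional[Type[HelpFormatter]] = None) -> Type[HelpFormatter]:
    """Returns a subclass of `base_cls` (default: `HelpFormatter`) that always formats for `terminal_width`."""
    if base_cls is None:
        base_cls = HelpFormatter

    class FixedWidthHelpFormatter(base_cls):
        def __init__(self, *args, **kwargs) -> None:
            # Override whatever width would otherwise be detected from the local terminal.
            kwargs['width'] = terminal_width
            super().__init__(*args, **kwargs)

    return FixedWidthHelpFormatter


def make_parser(terminal_width: int, **kwargs) -> ArgumentParser:
    """Mirrors the keyword handling shown in the hunk: the formatter class is always wrapped."""
    kwargs['formatter_class'] = help_formatter_factory(terminal_width, kwargs.get('formatter_class'))
    return ArgumentParser(**kwargs)


if __name__ == '__main__':
    demo = make_parser(40, prog='demo', formatter_class=ArgumentDefaultsHelpFormatter)
    demo.add_argument('--num', type=int, default=3, help="how many tasks to start in the pool")
    demo.print_help()
```

Subclassing instead of replacing the formatter keeps behaviour such as the default-value hints of `ArgumentDefaultsHelpFormatter` while still honouring the width reported by the client.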
@ -219,7 +217,7 @@ class ControlParser(ArgumentParser):
def error(self, message: str) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
super().error(message=message)
raise HelpRequested
raise ParserError
def print_help(self, file=None) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method."""
@ -267,9 +265,8 @@ class ControlParser(ArgumentParser):
# This is to be able to later unpack an arbitrary number of positional arguments.
kwargs.setdefault('nargs', '*')
if not kwargs.get('action') == 'store_true':
# The lambda wrapper around the type annotation is to avoid ValueError being raised on suppressed arguments.
# See: https://bugs.python.org/issue36078
kwargs.setdefault('type', get_arg_type_wrapper(parameter.annotation))
# Set the type from the parameter annotation.
kwargs.setdefault('type', _get_arg_type_wrapper(parameter.annotation))
return self.add_argument(*name_or_flags, **kwargs)
def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None:
@ -293,7 +290,13 @@ class ControlParser(ArgumentParser):
self.add_function_arg(param, help=repr(param.annotation))
def get_arg_type_wrapper(cls: Type) -> Callable:
def wrapper(arg):
return arg if arg is SUPPRESS else cls(arg)
def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
"""
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
See: https://bugs.python.org/issue36078
"""
def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg)
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
wrapper.__name__ = cls.__name__
return wrapper


@ -125,6 +125,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
async def serve_forever(self) -> Task:
"""
This method actually starts the server and begins listening to client connections on the specified interface.
It should never block because the serving will be performed in a separate task.
"""
log.debug("Starting %s...", self.__class__.__name__)
@ -136,7 +137,7 @@ class TCPControlServer(ControlServer):
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
_client_class = TCPControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
self._host = server_kwargs.pop('host')
self._port = server_kwargs.pop('port')
super().__init__(pool, **server_kwargs)
@ -154,7 +155,7 @@ class UnixControlServer(ControlServer):
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
from asyncio.streams import start_unix_server
self._start_unix_server = start_unix_server
self._socket_path = Path(server_kwargs.pop('path'))


@ -27,7 +27,7 @@ from inspect import isfunction, signature
from typing import Callable, Optional, Union, TYPE_CHECKING
from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from ..exceptions import CommandError, HelpRequested
from ..exceptions import CommandError, HelpRequested, ParserError
from ..helpers import return_or_exception
from ..pool import TaskPool, SimpleTaskPool
from .parser import ControlParser
@ -85,7 +85,7 @@ class ControlSession:
Must correspond to the arguments expected by the `method`.
Correctly unpacks arbitrary-length positional and keyword-arguments.
"""
log.warning("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__)
log.debug("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__)
normal_pos, var_pos = [], []
for param in signature(method).parameters.values():
if param.name == 'self':
@ -112,11 +112,11 @@ class ControlSession:
executed and the response written to the stream will be its return value (as an encoded string).
"""
if kwargs:
log.warning("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
await return_or_exception(prop.fset, self._pool, **kwargs)
self._writer.write(CMD_OK)
else:
log.warning("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
async def client_handshake(self) -> None:
@ -154,9 +154,11 @@ class ControlSession:
try:
kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e:
log.debug("%s got an ArgumentError", self._client_class_name)
self._writer.write(str(e).encode())
return
except HelpRequested:
except (HelpRequested, ParserError):
log.debug("%s received usage help", self._client_class_name)
return
command = kwargs.pop(CMD)
if isfunction(command):


@ -67,5 +67,9 @@ class HelpRequested(ServerException):
pass
class ParserError(ServerException):
pass
class CommandError(ServerException):
pass
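The new `ParserError` lets the session tell "the user asked for help" apart from "the input could not be parsed", even though both currently end the command early. A minimal standard-library sketch of that split is shown below. `SketchParser`, its `output` buffer, and `handle` are hypothetical; `ServerException` is assumed to derive from `Exception`, and unlike the real `ControlParser` this sketch does not call the parent class' `error` method (which would exit the process in a plain `ArgumentParser`).

```python
from argparse import ArgumentParser
from io import StringIO


class ServerException(Exception):  # base class assumed; its definition is not shown in the diff
    pass


class HelpRequested(ServerException):
    pass


class ParserError(ServerException):
    pass


class SketchParser(ArgumentParser):
    """Hypothetical parser that reports via a buffer and exceptions instead of exiting the process."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.output = StringIO()  # stands in for the stream writer

    def error(self, message: str) -> None:
        self.output.write(f"{self.prog}: error: {message}\n")
        raise ParserError

    def print_help(self, file=None) -> None:
        self.output.write(self.format_help())
        raise HelpRequested


def handle(parser: SketchParser, msg: str) -> str:
    """Parses one command line and reports which path was taken."""
    try:
        parser.parse_args(msg.split(' '))
    except HelpRequested:
        return 'help'
    except ParserError:
        return 'error'
    return 'ok'


if __name__ == '__main__':
    p = SketchParser(prog='pool')
    p.add_argument('--num', type=int)
    print(handle(p, '--num 3'), handle(p, '--num x'), handle(p, '-h'))
```

Keeping the two exceptions separate means a caller can later log or answer them differently without touching the parser again.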


@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Miscellaneous helper functions.
Miscellaneous helper functions. None of these should be considered part of the public API.
"""


@ -53,6 +53,6 @@ class Queue(_Queue):
Implements an asynchronous context manager for the queue.
Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want,
but in some situations it makes the codes much cleaner.
but in some situations it makes the code much cleaner.
"""
self.item_processed()
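The docstring in this hunk describes a queue that can be used as an asynchronous context manager. A minimal sketch of that idea on top of `asyncio.Queue` is given below. It assumes that entering the context awaits `get()` and that `item_processed()` is a thin alias for `task_done()`; neither definition is visible in this diff, and `SketchQueue` and `drain` are hypothetical names.

```python
import asyncio
from typing import Any, List


class SketchQueue(asyncio.Queue):
    """Hypothetical queue: entering the context gets an item, exiting marks it as processed."""

    def item_processed(self) -> None:
        # Assumed to be a thin alias for `task_done()`; the diff does not show its definition.
        self.task_done()

    async def __aenter__(self) -> Any:
        return await self.get()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Runs even if the block raised, which is why this is not always what you want.
        self.item_processed()


async def drain(items: List[int]) -> List[int]:
    q = SketchQueue()
    for item in items:
        q.put_nowait(item)
    seen = []
    while not q.empty():
        async with q as item:  # replaces a manual get() / try / finally / task_done()
            seen.append(item)
    await q.join()  # returns immediately because every item was marked as processed
    return seen


if __name__ == '__main__':
    print(asyncio.run(drain([1, 2, 3])))
```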


@ -25,7 +25,7 @@ import shutil
import sys
from pathlib import Path
from unittest import IsolatedAsyncioTestCase, skipIf
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, call, patch
from asyncio_taskpool.control import client
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES
@ -55,7 +55,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
def test_client_info(self):
self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns},
client.ControlClient.client_info())
client.ControlClient._client_info())
async def test_abstract(self):
with self.assertRaises(NotImplementedError):
@ -65,16 +65,19 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.assertEqual(self.kwargs, self.client._conn_kwargs)
self.assertFalse(self.client._connected)
@patch.object(client.ControlClient, 'client_info')
async def test__server_handshake(self, mock_client_info: MagicMock):
mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
@patch.object(client.ControlClient, '_client_info')
async def test__server_handshake(self, mock__client_info: MagicMock):
mock__client_info.return_value = mock_info = {FOO: 1, BAR: 9999}
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
self.assertTrue(self.client._connected)
mock_client_info.assert_called_once_with()
mock__client_info.assert_called_once_with()
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode())
self.mock_print.assert_has_calls([
call("Connected to", self.mock_read.return_value.decode()),
call("Type '-h' to get help and usage instructions for all available commands.\n")
])
@patch.object(client, 'input')
def test__get_command(self, mock_input: MagicMock):


@ -25,7 +25,7 @@ from unittest import TestCase
from unittest.mock import MagicMock, call, patch
from asyncio_taskpool.control import parser
from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.exceptions import HelpRequested, ParserError
FOO, BAR = 'foo', 'bar'
@ -41,7 +41,7 @@ class ControlServerTestCase(TestCase):
self.kwargs = {
'stream_writer': self.stream_writer,
'terminal_width': self.terminal_width,
parser.FORMATTER_CLASS: FOO
'formatter_class': FOO
}
self.parser = parser.ControlParser(**self.kwargs)
@ -183,7 +183,7 @@ class ControlServerTestCase(TestCase):
@patch.object(parser.ArgumentParser, 'error')
def test_error(self, mock_supercls_error: MagicMock):
with self.assertRaises(HelpRequested):
with self.assertRaises(ParserError):
self.parser.error(FOO + BAR)
mock_supercls_error.assert_called_once_with(message=FOO + BAR)
@ -194,11 +194,11 @@ class ControlServerTestCase(TestCase):
self.parser.print_help(arg)
mock_print_help.assert_called_once_with(arg)
@patch.object(parser, 'get_arg_type_wrapper')
@patch.object(parser, '_get_arg_type_wrapper')
@patch.object(parser.ArgumentParser, 'add_argument')
def test_add_function_arg(self, mock_add_argument: MagicMock, mock_get_arg_type_wrapper: MagicMock):
def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_arg_type_wrapper: MagicMock):
mock_add_argument.return_value = expected_output = 'action'
mock_get_arg_type_wrapper.return_value = mock_type = 'fake'
mock__get_arg_type_wrapper.return_value = mock_type = 'fake'
foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex
bar_default, baz_default, boo_default = 1, 0.1, 1j
@ -211,42 +211,42 @@ class ControlServerTestCase(TestCase):
kwargs = {FOO + BAR: 'xyz'}
self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs))
mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(foo_type)
mock__get_arg_type_wrapper.assert_called_once_with(foo_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs))
mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(args_type)
mock__get_arg_type_wrapper.assert_called_once_with(args_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs))
mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(bar_type)
mock__get_arg_type_wrapper.assert_called_once_with(bar_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs))
mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(baz_type)
mock__get_arg_type_wrapper.assert_called_once_with(baz_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs))
mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs)
mock_get_arg_type_wrapper.assert_called_once_with(boo_type)
mock__get_arg_type_wrapper.assert_called_once_with(boo_type)
mock_add_argument.reset_mock()
mock_get_arg_type_wrapper.reset_mock()
mock__get_arg_type_wrapper.reset_mock()
self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs))
mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs)
mock_get_arg_type_wrapper.assert_not_called()
mock__get_arg_type_wrapper.assert_not_called()
@patch.object(parser.ControlParser, 'add_function_arg')
def test_add_function_args(self, mock_add_function_arg: MagicMock):
@ -261,7 +261,8 @@ class ControlServerTestCase(TestCase):
class RestTestCase(TestCase):
def test_get_arg_type_wrapper(self):
type_wrap = parser.get_arg_type_wrapper(int)
def test__get_arg_type_wrapper(self):
type_wrap = parser._get_arg_type_wrapper(int)
self.assertEqual('int', type_wrap.__name__)
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
self.assertEqual(13, type_wrap('13'))


@ -1,14 +1,18 @@
# Using `asyncio-taskpool`
## Contents
- [Contents](#contents)
- [Minimal example for `SimpleTaskPool`](#minimal-example-for-simpletaskpool)
- [Advanced example for `TaskPool`](#advanced-example-for-taskpool)
- [Control server example](#control-server-example)
## Minimal example for `SimpleTaskPool`
With a `SimpleTaskPool` the function to execute as well as the arguments with which to execute it must be defined during its initialization (and they cannot be changed later). The only control you have after initialization is how many of such tasks are being run.
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
The following demo script enables full log output first for additional clarity. It is complete and should work as is.
```python
import logging
@ -48,7 +52,9 @@ if __name__ == '__main__':
asyncio.run(main())
```
### Output
<details>
<summary>Output: (Click to expand)</summary>
```
SimpleTaskPool-0 initialized
Started SimpleTaskPool-0_Task-0
@ -78,6 +84,7 @@ Ended SimpleTaskPool-0_Task-1
> did 4
> did 4
```
</details>
## Advanced example for `TaskPool`
@ -85,9 +92,7 @@ This time, we want to start tasks from _different_ coroutine functions **and** w
As with the simple example, we need "worker" coroutine functions that can do something asynchronously, as well as a main coroutine function that sets up the pool, starts the tasks, and eventually awaits them.
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
### Code
The following demo script enables full log output first for additional clarity. It is complete and should work as is.
```python
import logging
@ -144,10 +149,9 @@ if __name__ == '__main__':
asyncio.run(main())
```
### Output
Additional comments for the output are provided with `<---` next to the output lines.
<details>
<summary>Output: (Click to expand)</summary>
(Keep in mind that the logger and `print` asynchronously write to `stdout`.)
```
TaskPool-0 initialized
Started TaskPool-0_Task-0
@ -229,4 +233,37 @@ Ended TaskPool-0_Task-5
> Done.
```
(Added comments with `<---` next to the output lines.)
Keep in mind that the logger and `print` asynchronously write to `stdout`, so the order of lines in your output may be slightly different.
</details>
## Control server example
One of the main features of `asyncio-taskpool` is the ability to control a task pool "from the outside" at runtime.
The [example_server.py](./example_server.py) script launches a couple of worker tasks within a `SimpleTaskPool` instance and then starts a `TCPControlServer` instance for that task pool. The server is configured to locally bind to port `9999` and is stopped automatically after the "work" is done.
To run the script:
```shell
python usage/example_server.py
```
You can then connect to the server via the command line interface:
```shell
python -m asyncio_taskpool.control tcp localhost 9999
```
The CLI starts a `TCPControlClient` that connects to our example server. Once the connection is established, it gives you an input prompt allowing you to issue commands to the task pool:
```
Connected to SimpleTaskPool-0
Type '-h' to get help and usage instructions for all available commands.
>
```
It may be useful to run the server script and the client interface in two separate terminal windows side by side. The server script is configured with a verbose logger and will react to any commands issued by the client with detailed log messages in the terminal.
---
© 2022 Daniil Fajnberg


@ -65,12 +65,12 @@ async def main() -> None:
# We just put some integers into our queue, since all our workers actually do is print an item and sleep for a bit.
for item in range(100):
q.put_nowait(item)
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
await pool.start(3) # launches three worker tasks
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join()
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
# Since we don't need any "work" done anymore, we can get rid of our control server by cancelling the task.
control_server_task.cancel()
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks.
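The shutdown order described in these comments (wait for the queue, cancel the control server task, then cancel the idle workers) can be sketched with plain `asyncio`. The example below deliberately uses neither `SimpleTaskPool` nor `TCPControlServer`; `background_service` is a hypothetical placeholder for the server task, and the workers are ordinary tasks.

```python
import asyncio
from typing import List


async def worker(q: asyncio.Queue, done: List[int]) -> None:
    while True:
        item = await q.get()
        await asyncio.sleep(0)  # stand-in for real work
        done.append(item)
        q.task_done()


async def background_service() -> None:
    """Placeholder for the control server task; it just idles until cancelled."""
    await asyncio.Event().wait()


async def main() -> List[int]:
    q: asyncio.Queue = asyncio.Queue()
    done: List[int] = []
    for item in range(10):
        q.put_nowait(item)
    workers = [asyncio.create_task(worker(q, done)) for _ in range(3)]
    service = asyncio.create_task(background_service())
    await q.join()  # blocks until `task_done()` was called once per item
    service.cancel()  # no more work is expected, so the service is no longer needed
    for task in workers:
        task.cancel()  # workers are idle in `q.get()` now, so cancelling loses nothing
    # Collect the cancellations so that no task is left pending when the loop closes.
    await asyncio.gather(service, *workers, return_exceptions=True)
    return sorted(done)


if __name__ == '__main__':
    print(asyncio.run(main()))
```

Waiting on `q.join()` before cancelling anything is what makes the cancellation safe: at that point every item has been fully processed and the workers can only be waiting for input that will never arrive.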