generated from daniil-berg/boilerplate-py
	Compare commits
	
		
			34 Commits
		
	
	
		
			v0.3.5-lw
			...
			80fc91ec47
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 80fc91ec47 | |||
| a72a7cc516 | |||
| 91d546ebc2 | |||
| 5b3ac52bf6 | |||
| 82e6ca7b1a | |||
| 153127e028 | |||
| 17539e9c27 | |||
| 1beb9fc9b0 | |||
| 23a4cb028a | |||
| 54e5bfa8a0 | |||
| 0e7e92a91b | |||
| a9011076c4 | |||
| 7e34aa106d | |||
| 4c6a5412ca | |||
| 44c03cc493 | |||
| 689a74c678 | |||
| 3503c0bf44 | |||
| 3d104c979e | |||
| a92e646411 | |||
| 3d84e1552b | |||
| 38f4ec1b06 | |||
| 6f082288d8 | |||
| 9fde231250 | |||
| c72a5035ea | |||
| eb152e4d75 | |||
| d05f84b2c3 | |||
| 7c66604ad0 | |||
| 287906a218 | |||
| ce0f9a1f65 | |||
| 5dad4ab0c7 | |||
| ae6bb1bd17 | |||
| e501a849f3 | |||
| ed6badb088 | |||
| c63f079da4 | 
| @@ -5,7 +5,6 @@ omit = | ||||
|     .venv/* | ||||
|  | ||||
| [report] | ||||
| fail_under = 100 | ||||
| show_missing = True | ||||
| skip_covered = False | ||||
| exclude_lines = | ||||
|   | ||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -3,9 +3,10 @@ | ||||
| # IDE settings: | ||||
| /.idea/ | ||||
| /.vscode/ | ||||
| # Distribution / packaging: | ||||
| # Distribution / build files: | ||||
| *.egg-info/ | ||||
| /dist/ | ||||
| /docs/build/ | ||||
| # Python cache: | ||||
| __pycache__/ | ||||
| # Testing: | ||||
|   | ||||
							
								
								
									
										28
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								README.md
									
									
									
									
									
								
							| @@ -2,9 +2,18 @@ | ||||
|  | ||||
| **Dynamically manage pools of asyncio tasks** | ||||
|  | ||||
| ## Contents | ||||
| - [Contents](#contents) | ||||
| - [Summary](#summary) | ||||
| - [Usage](#usage) | ||||
| - [Installation](#installation) | ||||
| - [Dependencies](#dependencies) | ||||
| - [Testing](#testing) | ||||
| - [License](#license) | ||||
|  | ||||
| ## Summary | ||||
|  | ||||
| A task pool is an object with a simple interface for aggregating and dynamically managing asynchronous tasks. | ||||
| A **task pool** is an object with a simple interface for aggregating and dynamically managing asynchronous tasks. | ||||
|  | ||||
| With an interface that is intentionally similar to the [`multiprocessing.Pool`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool) class from the standard library, the `TaskPool` provides you such methods as `apply`, `map`, and `starmap` to execute coroutines concurrently as [`asyncio.Task`](https://docs.python.org/3/library/asyncio-task.html#task-object) objects. There is no limitation imposed on what kind of tasks can be run or in what combination, when new ones can be added, or when they can be cancelled. | ||||
|  | ||||
| @@ -15,23 +24,22 @@ If you need control over a task pool at runtime, you can launch an asynchronous | ||||
| ## Usage | ||||
|  | ||||
| Generally speaking, a task is added to a pool by providing it with a coroutine function reference as well as the arguments for that function. Here is what that could look like in the most simplified form: | ||||
|  | ||||
| ```python | ||||
| from asyncio_taskpool import SimpleTaskPool | ||||
| ... | ||||
| async def work(foo, bar): ... | ||||
| ... | ||||
| async def work(_foo, _bar): ... | ||||
|  | ||||
| async def main(): | ||||
|     pool = SimpleTaskPool(work, args=('xyz', 420)) | ||||
|     await pool.start(5) | ||||
|     pool.start(5) | ||||
|     ... | ||||
|     pool.stop(3) | ||||
|     ... | ||||
|     pool.lock() | ||||
|     await pool.gather() | ||||
|     ... | ||||
|     await pool.gather_and_close() | ||||
| ``` | ||||
|  | ||||
| Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.) | ||||
| Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.) | ||||
|  | ||||
| For working and fully documented demo scripts see [USAGE.md](usage/USAGE.md). | ||||
|  | ||||
| @@ -47,7 +55,7 @@ Python Version 3.8+, tested on Linux | ||||
|  | ||||
| ## Testing | ||||
|  | ||||
| Install `asyncio-taskpool[dev]` dependencies or just manually install `coverage` with `pip`.  | ||||
| Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.  | ||||
| Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report. | ||||
|  | ||||
| ## License | ||||
| @@ -56,6 +64,6 @@ Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests an | ||||
|  | ||||
| The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COPYING.LESSER) are included in this repository. If not, see https://www.gnu.org/licenses/. | ||||
|  | ||||
| ## Copyright | ||||
| --- | ||||
|  | ||||
| © 2022 Daniil Fajnberg | ||||
|   | ||||
							
								
								
									
										20
									
								
								docs/Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								docs/Makefile
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| # Minimal makefile for Sphinx documentation | ||||
| # | ||||
|  | ||||
| # You can set these variables from the command line, and also | ||||
| # from the environment for the first two. | ||||
| SPHINXOPTS    ?= | ||||
| SPHINXBUILD   ?= sphinx-build | ||||
| SOURCEDIR     = source | ||||
| BUILDDIR      = build | ||||
|  | ||||
| # Put it first so that "make" without argument is like "make help". | ||||
| help: | ||||
| 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | ||||
|  | ||||
| .PHONY: help Makefile | ||||
|  | ||||
| # Catch-all target: route all unknown targets to Sphinx using the new | ||||
| # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS). | ||||
| %: Makefile | ||||
| 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | ||||
							
								
								
									
										35
									
								
								docs/make.bat
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								docs/make.bat
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| @ECHO OFF | ||||
|  | ||||
| pushd %~dp0 | ||||
|  | ||||
| REM Command file for Sphinx documentation | ||||
|  | ||||
| if "%SPHINXBUILD%" == "" ( | ||||
| 	set SPHINXBUILD=sphinx-build | ||||
| ) | ||||
| set SOURCEDIR=source | ||||
| set BUILDDIR=build | ||||
|  | ||||
| if "%1" == "" goto help | ||||
|  | ||||
| %SPHINXBUILD% >NUL 2>NUL | ||||
| if errorlevel 9009 ( | ||||
| 	echo. | ||||
| 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx | ||||
| 	echo.installed, then set the SPHINXBUILD environment variable to point | ||||
| 	echo.to the full path of the 'sphinx-build' executable. Alternatively you | ||||
| 	echo.may add the Sphinx directory to PATH. | ||||
| 	echo. | ||||
| 	echo.If you don't have Sphinx installed, grab it from | ||||
| 	echo.https://www.sphinx-doc.org/ | ||||
| 	exit /b 1 | ||||
| ) | ||||
|  | ||||
| %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% | ||||
| goto end | ||||
|  | ||||
| :help | ||||
| %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% | ||||
|  | ||||
| :end | ||||
| popd | ||||
							
								
								
									
										7
									
								
								docs/source/api/api.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/api.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| API | ||||
| === | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 4 | ||||
|  | ||||
|    asyncio_taskpool | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.client.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.client.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.control.client module | ||||
| ======================================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.control.client | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.parser.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.parser.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.control.parser module | ||||
| ======================================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.control.parser | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										18
									
								
								docs/source/api/asyncio_taskpool.control.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								docs/source/api/asyncio_taskpool.control.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
| asyncio\_taskpool.control package | ||||
| ================================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.control | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
|  | ||||
| Submodules | ||||
| ---------- | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 4 | ||||
|  | ||||
|    asyncio_taskpool.control.client | ||||
|    asyncio_taskpool.control.parser | ||||
|    asyncio_taskpool.control.server | ||||
|    asyncio_taskpool.control.session | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.server.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.server.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.control.server module | ||||
| ======================================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.control.server | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.session.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.control.session.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.control.session module | ||||
| ======================================== | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.control.session | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.exceptions.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.exceptions.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.exceptions module | ||||
| =================================== | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.exceptions | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.pool.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.pool.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.pool module | ||||
| ============================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.pool | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										7
									
								
								docs/source/api/asyncio_taskpool.queue_context.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								docs/source/api/asyncio_taskpool.queue_context.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| asyncio\_taskpool.queue\_context module | ||||
| ======================================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool.queue_context | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
							
								
								
									
										25
									
								
								docs/source/api/asyncio_taskpool.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								docs/source/api/asyncio_taskpool.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| asyncio\_taskpool package | ||||
| ========================= | ||||
|  | ||||
| .. automodule:: asyncio_taskpool | ||||
|    :members: | ||||
|    :undoc-members: | ||||
|    :show-inheritance: | ||||
|  | ||||
| Subpackages | ||||
| ----------- | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 4 | ||||
|  | ||||
|    asyncio_taskpool.control | ||||
|  | ||||
| Submodules | ||||
| ---------- | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 4 | ||||
|  | ||||
|    asyncio_taskpool.exceptions | ||||
|    asyncio_taskpool.pool | ||||
|    asyncio_taskpool.queue_context | ||||
							
								
								
									
										60
									
								
								docs/source/conf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								docs/source/conf.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| # Configuration file for the Sphinx documentation builder. | ||||
| # | ||||
| # This file only contains a selection of the most common options. For a full | ||||
| # list see the documentation: | ||||
| # https://www.sphinx-doc.org/en/master/usage/configuration.html | ||||
|  | ||||
| # -- Path setup -------------------------------------------------------------- | ||||
|  | ||||
| # If extensions (or modules to document with autodoc) are in another directory, | ||||
| # add these directories to sys.path here. If the directory is relative to the | ||||
| # documentation root, use os.path.abspath to make it absolute, like shown here. | ||||
| # | ||||
| # import os | ||||
| # import sys | ||||
| # sys.path.insert(0, os.path.abspath('.')) | ||||
|  | ||||
|  | ||||
| # -- Project information ----------------------------------------------------- | ||||
|  | ||||
| project = 'asyncio-taskpool' | ||||
| copyright = '2022 Daniil Fajnberg' | ||||
| author = 'Daniil Fajnberg' | ||||
|  | ||||
| # The full version, including alpha/beta/rc tags | ||||
| release = '1.0.0-beta' | ||||
|  | ||||
|  | ||||
| # -- General configuration --------------------------------------------------- | ||||
|  | ||||
| # Add any Sphinx extension module names here, as strings. They can be | ||||
| # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | ||||
| # ones. | ||||
| extensions = [ | ||||
|     'sphinx.ext.duration', | ||||
|     'sphinx.ext.napoleon' | ||||
| ] | ||||
|  | ||||
| # Add any paths that contain templates here, relative to this directory. | ||||
| templates_path = ['_templates'] | ||||
|  | ||||
| # List of patterns, relative to source directory, that match files and | ||||
| # directories to ignore when looking for source files. | ||||
| # This pattern also affects html_static_path and html_extra_path. | ||||
| exclude_patterns = [] | ||||
|  | ||||
|  | ||||
| # -- Options for HTML output ------------------------------------------------- | ||||
|  | ||||
| # The theme to use for HTML and HTML Help pages.  See the documentation for | ||||
| # a list of builtin themes. | ||||
| # | ||||
| html_theme = 'sphinx_rtd_theme' | ||||
| html_theme_options = { | ||||
|     'style_external_links': True, | ||||
| } | ||||
|  | ||||
| # Add any paths that contain custom static files (such as style sheets) here, | ||||
| # relative to this directory. They are copied after the builtin static files, | ||||
| # so a file named "default.css" will overwrite the builtin "default.css". | ||||
| html_static_path = ['_static'] | ||||
							
								
								
									
										57
									
								
								docs/source/index.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								docs/source/index.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| .. This file is part of asyncio-taskpool. | ||||
|  | ||||
| .. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
|    version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| .. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; | ||||
|    without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | ||||
|    See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| .. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. | ||||
|    If not, see <https://www.gnu.org/licenses/>. | ||||
|  | ||||
| .. Copyright © 2022 Daniil Fajnberg | ||||
|  | ||||
|  | ||||
| Welcome to the asyncio-taskpool documentation! | ||||
| ============================================== | ||||
|  | ||||
| :code:`asyncio-taskpool` is a Python library for dynamically and conveniently managing pools of `asyncio <https://docs.python.org/3/library/asyncio.html>`_ tasks. | ||||
|  | ||||
| Purpose | ||||
| ------- | ||||
|  | ||||
| A `task <https://docs.python.org/3/library/asyncio-task.html>`_ is a very powerful tool of concurrency in the Python world. Since concurrency always implies doing more than one thing a time, you rarely deal with just one :code:`Task` instance. However, managing multiple tasks can become a bit cumbersome quickly, as their number increases. Moreover, especially in long-running code, you may find it useful (or even necessary) to dynamically adjust the extent to which the work is distributed, i.e. increase or decrease the number of tasks. | ||||
|  | ||||
| With that in mind, this library aims to provide two things: | ||||
|  | ||||
| #. An additional layer of abstraction and convenience for managing multiple tasks. | ||||
| #. A simple interface for dynamically adding and removing tasks when a program is already running. | ||||
|  | ||||
| The first is achieved through the concept of a :doc:`task pool <pages/pool>`. The second is achieved by adding a :doc:`control server <pages/control>` to the task pool. | ||||
|  | ||||
| Installation | ||||
| ------------ | ||||
|  | ||||
| .. code-block:: bash | ||||
|  | ||||
|    $ pip install asyncio-taskpool | ||||
|  | ||||
|  | ||||
| Contents | ||||
| -------- | ||||
|  | ||||
| .. toctree:: | ||||
|    :maxdepth: 2 | ||||
|  | ||||
|    pages/pool | ||||
|    pages/control | ||||
|    api/api | ||||
|  | ||||
|  | ||||
| Indices and tables | ||||
| ------------------ | ||||
|  | ||||
| * :ref:`genindex` | ||||
| * :ref:`modindex` | ||||
| * :ref:`search` | ||||
							
								
								
									
										107
									
								
								docs/source/pages/control.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								docs/source/pages/control.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,107 @@ | ||||
| .. This file is part of asyncio-taskpool. | ||||
|  | ||||
| .. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
|    version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| .. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; | ||||
|    without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | ||||
|    See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| .. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. | ||||
|    If not, see <https://www.gnu.org/licenses/>. | ||||
|  | ||||
| .. Copyright © 2022 Daniil Fajnberg | ||||
|  | ||||
| Control interface | ||||
| ================= | ||||
|  | ||||
| When you are dealing with programs that run for a long period of time or even as daemons (i.e. indefinitely), having a way to adjust their behavior without needing to stop and restart them can be desirable. | ||||
|  | ||||
| Task pools offer a high degree of flexibility regarding the number and kind of tasks that run within them, by providing methods to easily start and stop tasks and task groups. But without additional tools, they only allow you to establish a control logic *a priori*, as demonstrated in :ref:`this code snippet <simple-control-logic>`. | ||||
|  | ||||
| What if you have a long-running program that executes certain tasks concurrently, but you don't know in advance how many of them you'll need? What if you want to be able to adjust the number of tasks manually **without stopping the task pool**? | ||||
|  | ||||
| The control server | ||||
| ------------------ | ||||
|  | ||||
| The :code:`asyncio-taskpool` library comes with a simple control interface for managing task pools that are already running, at the heart of which is the :py:class:`ControlServer <asyncio_taskpool.control.server.ControlServer>`. Any task pool can be passed to a control server. Once the server is running, you can issue commands to it either via TCP or via UNIX socket. The commands map directly to the task pool methods. | ||||
|  | ||||
| To enable control over a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` via local TCP port :code:`8001`, all you need to do is this: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|    :name: control-server-minimal | ||||
|  | ||||
|    from asyncio_taskpool import SimpleTaskPool | ||||
|    from asyncio_taskpool.control import TCPControlServer | ||||
|    from .work import any_worker_func | ||||
|  | ||||
|  | ||||
|    async def main(): | ||||
|        ... | ||||
|        pool = SimpleTaskPool(any_worker_func, kwargs={'foo': 42, 'bar': some_object}) | ||||
|        control = await TCPControlServer(pool, host='127.0.0.1', port=8001).serve_forever() | ||||
|        await control | ||||
|  | ||||
| Under the hood, the :py:class:`ControlServer <asyncio_taskpool.control.server.ControlServer>` simply uses :code:`asyncio.start_server` for instantiating a socket server. The resulting control task will run indefinitely. Cancelling the control task stops the server. | ||||
|  | ||||
| In reality, you would probably want some graceful handler for an interrupt signal that cancels any remaining tasks as well as the serving control task. | ||||
|  | ||||
| The control client | ||||
| ------------------ | ||||
|  | ||||
| Technically, any process that can read from and write to the socket exposed by the control server, will be able to interact with it. The :code:`asyncio-taskpool` package has its own simple implementation in the form of the :py:class:`ControlClient <asyncio_taskpool.control.client.ControlClient>` that makes it easy to use out of the box. | ||||
|  | ||||
| To start a client, you can use the main script of the :py:mod:`asyncio_taskpool.control` sub-package like this: | ||||
|  | ||||
| .. code-block:: bash | ||||
|  | ||||
|    $ python -m asyncio_taskpool.control tcp localhost 8001 | ||||
|  | ||||
| This would establish a connection to the control server from the previous example. Calling | ||||
|  | ||||
| .. code-block:: bash | ||||
|  | ||||
|    $ python -m asyncio_taskpool.control -h | ||||
|  | ||||
| will display the available client options. | ||||
|  | ||||
| The control session | ||||
| ------------------- | ||||
|  | ||||
| Assuming you connected successfully, you should be greeted by the server with a help message and dropped into a simple input prompt. | ||||
|  | ||||
| .. code-block:: none | ||||
|  | ||||
|    Connected to SimpleTaskPool-0 | ||||
|    Type '-h' to get help and usage instructions for all available commands. | ||||
|  | ||||
|    > | ||||
|  | ||||
| The input sent to the server is handled by a typical argument parser, so the interface should be straight-forward. A command like | ||||
|  | ||||
| .. code-block:: none | ||||
|  | ||||
|    > start 5 | ||||
|  | ||||
| will call the :py:meth:`.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method with :code:`5` as an argument and thus start 5 new tasks in the pool, while the command | ||||
|  | ||||
| .. code-block:: none | ||||
|  | ||||
|    > pool-size | ||||
|  | ||||
| will call the :py:meth:`.pool_size <asyncio_taskpool.pool.BaseTaskPool.pool_size>` property getter and return the maximum number of tasks you that can run in the pool. | ||||
|  | ||||
| When you are dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, starting new tasks works just fine, as long as the coroutine functions you want to use can be imported into the namespace of the pool. If you have a function named :code:`worker` in the module :code:`mymodule` under the package :code:`mypackage` and want to use it in a :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` call with the arguments :code:`'x'`, :code:`'x'`, and :code:`'z'`, you would do it like this: | ||||
|  | ||||
| .. code-block:: none | ||||
|  | ||||
|    > map mypackage.mymodule.worker ['x','y','z'] -n 3 | ||||
|  | ||||
| The :code:`-n` is a shorthand for :code:`--num-concurrent` in this case. In general, all (public) pool methods will have a corresponding command in the control session. | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|    The :code:`ast.literal_eval` function from the `standard library <https://docs.python.org/3/library/ast.html#ast.literal_eval>`_ is used to safely evaluate the iterable of arguments to work on. For obvious reasons, being able to provide arbitrary python objects in such a control session is neither practical nor secure. The way this is implemented now is limited in that regard, since you can only use Python literals and containers as arguments for your coroutine functions. | ||||
|  | ||||
| To exit a control session, use the :code:`exit` command or simply press :code:`Ctrl + D`. | ||||
							
								
								
									
										233
									
								
								docs/source/pages/pool.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								docs/source/pages/pool.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,233 @@ | ||||
| .. This file is part of asyncio-taskpool. | ||||
|  | ||||
| .. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
|    version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| .. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; | ||||
|    without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | ||||
|    See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| .. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. | ||||
|    If not, see <https://www.gnu.org/licenses/>. | ||||
|  | ||||
| .. Copyright © 2022 Daniil Fajnberg | ||||
|  | ||||
|  | ||||
| Task pools | ||||
| ========== | ||||
|  | ||||
| What is a task pool? | ||||
| -------------------- | ||||
|  | ||||
| A task pool is an object with a simple interface for aggregating and dynamically managing asynchronous tasks. | ||||
|  | ||||
| To make use of task pools, your code obviously needs to contain coroutine functions (introduced with the :code:`async def` keywords). By adding such functions along with their arguments to a task pool, they are turned into tasks and executed asynchronously. | ||||
|  | ||||
| If you are familiar with the :code:`Pool` class of the `multiprocessing module <https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool>`_ from the standard library, then you should feel at home with the :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` class. Obviously, there are major conceptual and functional differences between the two, but the methods provided by the :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` follow a very similar logic. If you never worked with process or thread pools, don't worry. Task pools are much simpler. | ||||
|  | ||||
| The :code:`TaskPool` class | ||||
| -------------------------- | ||||
|  | ||||
| There are essentially two distinct use cases for a concurrency pool. You want to | ||||
|  | ||||
| #. execute a function *n* times with the same arguments concurrently or | ||||
| #. execute a function *n* times with different arguments concurrently. | ||||
|  | ||||
| The first is accomplished with the :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` method, while the second is accomplished with the :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` method and its variations :py:meth:`.starmap() <asyncio_taskpool.pool.TaskPool.starmap>` and :py:meth:`.doublestarmap() <asyncio_taskpool.pool.TaskPool.doublestarmap>`. | ||||
|  | ||||
| Let's take a look at an example. Say you have a coroutine function that takes two queues as arguments: The first one being an input-queue (containing items to work on) and the second one being the output queue (for passing on the results to some other function). Your function may look something like this: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: work.py | ||||
|    :name: queue-worker-function | ||||
|  | ||||
|    from asyncio.queues import Queue | ||||
|  | ||||
|    async def queue_worker_function(in_queue: Queue, out_queue: Queue) -> None: | ||||
|        while True: | ||||
|            item = await in_queue.get() | ||||
|            ... # Do some work on the item and arrive at a result. | ||||
|            await out_queue.put(result) | ||||
|  | ||||
| How would we go about concurrently executing this function, say 5 times? There are (as always) a number of ways to do this with :code:`asyncio`. If we want to use tasks and be clean about it, we can do it like this: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|  | ||||
|    from asyncio.tasks import create_task, gather | ||||
|    from .work import queue_worker_function | ||||
|  | ||||
|    ... | ||||
|    # We assume that the queues have been initialized already. | ||||
|    tasks = [] | ||||
|    for _ in range(5): | ||||
|        new_task = create_task(queue_worker_function(q_in, q_out)) | ||||
|        tasks.append(new_task) | ||||
|    # Run some other code and let the tasks do their thing. | ||||
|    ... | ||||
|    # At some point, we want the tasks to stop waiting for new items and end. | ||||
|    for task in tasks: | ||||
|        task.cancel() | ||||
|    ... | ||||
|    await gather(*tasks) | ||||
|  | ||||
| By contrast, here is how you would do it with a task pool: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|  | ||||
|    from asyncio_taskpool import TaskPool | ||||
|    from .work import queue_worker_function | ||||
|  | ||||
|    ... | ||||
|    pool = TaskPool() | ||||
|    group_name = pool.apply(queue_worker_function, args=(q_in, q_out), num=5) | ||||
|    ... | ||||
|    pool.cancel_group(group_name) | ||||
|    ... | ||||
|    await pool.flush() | ||||
|  | ||||
| Pretty much self-explanatory, no? | ||||
|  | ||||
| Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: work.py | ||||
|    :name: another-worker-function | ||||
|  | ||||
|    from .my_database_stuff import insert_into_results_table | ||||
|  | ||||
|    async def another_worker_function(data: object) -> None: | ||||
|        if data.some_attribute > 1: | ||||
|            ... | ||||
|        # Do the work, arrive at results. | ||||
|        await insert_into_results_table(results) | ||||
|  | ||||
| Say we have some *iterator* of data-items (of arbitrary length) that we want to be worked on, and say we want 5 coroutines concurrently working on that data. Here is a very naive task-based solution: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|  | ||||
|    from asyncio.tasks import create_task, gather | ||||
|    from .work import another_worker_function | ||||
|  | ||||
|    async def main(): | ||||
|        ... | ||||
|        # We got our data_iterator from somewhere. | ||||
|        keep_going = True | ||||
|        while keep_going: | ||||
|            tasks = [] | ||||
|            for _ in range(5): | ||||
|                try: | ||||
|                    data = next(data_iterator) | ||||
|                except StopIteration: | ||||
|                    keep_going = False | ||||
|                    break | ||||
|                new_task = create_task(another_worker_function(data)) | ||||
|                tasks.append(new_task) | ||||
|            await gather(*tasks) | ||||
|  | ||||
| Here we already run into problems with the task-based approach. The last line in our :code:`while`-loop blocks until **all 5 tasks** return (or raise an exception). This means that as soon as one of them returns, the number of working coroutines is already less than 5 (until all the others return). This can obviously be solved in different ways. We could, for instance, wrap the creation of new tasks itself in a coroutine, which immediately creates a new task, when one is finished, and then call that coroutine 5 times concurrently. Or we could use the queue-based approach from before, but then we would need to write some queue producing coroutine. | ||||
|  | ||||
| Or we could use a task pool: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|  | ||||
|    from asyncio_taskpool import TaskPool | ||||
|    from .work import another_worker_function | ||||
|  | ||||
|  | ||||
|    async def main(): | ||||
|        ... | ||||
|        pool = TaskPool() | ||||
|        pool.map(another_worker_function, data_iterator, num_concurrent=5) | ||||
|        ... | ||||
|        await pool.gather_and_close() | ||||
|  | ||||
| Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always**  -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction). | ||||
|  | ||||
| The :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`) | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|    Neither :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` nor :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` return coroutines. When they are called, the task pool immediately begins scheduling new tasks to run. No :code:`await` needed. | ||||
|  | ||||
| It can't get any simpler than that, can it? So glad you asked... | ||||
|  | ||||
| The :code:`SimpleTaskPool` class | ||||
| -------------------------------- | ||||
|  | ||||
| Let's take the :ref:`queue worker example <queue-worker-function>` from before. If we know that the task pool will only ever work with that one function with the same queue objects, we can make use of the :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` class: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|  | ||||
|    from asyncio_taskpool import SimpleTaskPool | ||||
|    from .work import queue_worker_function | ||||
|  | ||||
|  | ||||
|    async def main(): | ||||
|        ... | ||||
|        pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out)) | ||||
|        pool.start(5) | ||||
|        ... | ||||
|        pool.stop_all() | ||||
|        ... | ||||
|        await pool.gather_and_close() | ||||
|  | ||||
| This may, at first glance, not seem like much of a difference, aside from different method names. However, assume that our main function runs a loop and needs to be able to periodically regulate the number of tasks being executed in the pool based on some additional variables it receives. With the :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>`, this could not be simpler: | ||||
|  | ||||
| .. code-block:: python | ||||
|    :caption: main.py | ||||
|    :name: simple-control-logic | ||||
|  | ||||
|    from asyncio_taskpool import SimpleTaskPool | ||||
|    from .work import queue_worker_function | ||||
|  | ||||
|  | ||||
|    async def main(): | ||||
|        ... | ||||
|        pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out)) | ||||
|        await pool.start(5) | ||||
|        while True: | ||||
|           ... | ||||
|           if some_condition and pool.num_running > 10: | ||||
|               pool.stop(3) | ||||
|           elif some_other_condition and pool.num_running < 5: | ||||
|               pool.start(5) | ||||
|           else: | ||||
|               pool.start(1) | ||||
|           ... | ||||
|        await pool.gather_and_close() | ||||
|  | ||||
| Notice how we only specify the function and its arguments during initialization of the pool. From that point on, all we need is the :py:meth:`.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` add :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` methods to adjust the number of concurrently running tasks. | ||||
|  | ||||
| The trade-off here is that this simplified task pool class lacks the flexibility of the regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` class. On an instance of the latter we can call :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` and :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` as often as we like with completely unrelated functions and arguments. With a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>`, once you initialize it, it is pegged to one function and one set of arguments, and all you can do is control the number of tasks working with those. | ||||
|  | ||||
| This simplified interface becomes particularly useful in conjunction with the :doc:`control server <./control>`. | ||||
|  | ||||
| .. _blocking-pool-methods: | ||||
|  | ||||
| (Non-)Blocking pool methods | ||||
| --------------------------- | ||||
|  | ||||
| One of the main concerns when dealing with concurrent programs in general and with :code:`async` functions in particular is when and how a particular piece of code **blocks** during execution, i.e. delays the execution of the following code significantly. | ||||
|  | ||||
| .. note:: | ||||
|  | ||||
|    Every statement will block to *some* extent. Obviously, when a program does something, that takes time. This is why the proper question to ask is not *if* but *to what extent, under which circumstances* the execution of a particular line of code blocks. | ||||
|  | ||||
| It is fair to assume that anyone reading this is familiar enough with the concepts of asynchronous programming in Python to know that just slapping :code:`async` in front of a function definition will not magically make it suitable for concurrent execution (in any meaningful way). Therefore, we assume that you are dealing with coroutines that can actually unblock the `event loop  <https://docs.python.org/3/library/asyncio-eventloop.html>`_ (e.g. doing a significant amount of I/O). | ||||
|  | ||||
| So how does the task pool behave in that regard? | ||||
|  | ||||
| The only method of a pool that one should **always** assume to be blocking is :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>`. This method awaits **all** tasks in the pool, meaning as long as one of them is still running, this coroutine will not return. | ||||
|  | ||||
| .. warning:: | ||||
|  | ||||
|    This includes awaiting any callbacks that were passed along with the tasks. | ||||
|  | ||||
| One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations. | ||||
|  | ||||
| All methods that add tasks to a pool, i.e. :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants), :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>`, are non-blocking by design. They all make use of "meta tasks" under the hood and return immediately. It is important however, to realize that just because they return, does not mean that any actual tasks have been spawned. For example, if a pool size limit was set and there was "no more room" in the pool when :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` was called, there is **no guarantee** that even a single task has started, when it returns. | ||||
| @@ -1,2 +1,4 @@ | ||||
| -r common.txt | ||||
| coverage | ||||
| sphinx | ||||
| sphinx-rtd-theme | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| [metadata] | ||||
| name = asyncio-taskpool | ||||
| version = 0.3.5 | ||||
| version = 1.0.0-beta | ||||
| author = Daniil Fajnberg | ||||
| author_email = mail@daniil.fajnberg.de | ||||
| description = Dynamically manage pools of asyncio tasks | ||||
| @@ -9,9 +9,9 @@ long_description_content_type = text/markdown | ||||
| keywords = asyncio, concurrency, tasks, coroutines, asynchronous, server | ||||
| url = https://git.fajnberg.de/daniil/asyncio-taskpool | ||||
| project_urls = | ||||
|     Bug Tracker = https://git.fajnberg.de/daniil/asyncio-taskpool/issues | ||||
|     Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues | ||||
| classifiers = | ||||
|     Development Status :: 3 - Alpha | ||||
|     Development Status :: 4 - Beta | ||||
|     Programming Language :: Python :: 3 | ||||
|     Operating System :: OS Independent | ||||
|     License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) | ||||
| @@ -30,6 +30,8 @@ python_requires = >=3.8 | ||||
| [options.extras_require] | ||||
| dev = | ||||
|     coverage | ||||
|     sphinx | ||||
|     sphinx-rtd-theme | ||||
|  | ||||
| [options.packages.find] | ||||
| where = src | ||||
|   | ||||
| @@ -14,10 +14,5 @@ See the GNU Lesser General Public License for more details. | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Brings the main classes up to package level for import convenience. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from .pool import TaskPool, SimpleTaskPool | ||||
| from .server import UnixControlServer | ||||
|   | ||||
| @@ -1,67 +0,0 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| CLI client entry point. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import sys | ||||
| from argparse import ArgumentParser | ||||
| from asyncio import run | ||||
| from pathlib import Path | ||||
| from typing import Dict, Any | ||||
|  | ||||
| from .client import ControlClient, UnixControlClient | ||||
| from .constants import PACKAGE_NAME | ||||
| from .pool import TaskPool | ||||
| from .server import ControlServer | ||||
|  | ||||
|  | ||||
| CONN_TYPE = 'conn_type' | ||||
| UNIX, TCP = 'unix', 'tcp' | ||||
| SOCKET_PATH = 'path' | ||||
|  | ||||
|  | ||||
| def parse_cli() -> Dict[str, Any]: | ||||
|     parser = ArgumentParser( | ||||
|         prog=PACKAGE_NAME, | ||||
|         description=f"CLI based {ControlClient.__name__} for {PACKAGE_NAME}" | ||||
|     ) | ||||
|     subparsers = parser.add_subparsers(title="Connection types", dest=CONN_TYPE) | ||||
|     unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket") | ||||
|     unix_parser.add_argument( | ||||
|         SOCKET_PATH, | ||||
|         type=Path, | ||||
|         help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening." | ||||
|     ) | ||||
|     return vars(parser.parse_args()) | ||||
|  | ||||
|  | ||||
| async def main(): | ||||
|     kwargs = parse_cli() | ||||
|     if kwargs[CONN_TYPE] == UNIX: | ||||
|         client = UnixControlClient(path=kwargs[SOCKET_PATH]) | ||||
|     elif kwargs[CONN_TYPE] == TCP: | ||||
|         # TODO: Implement the TCP client class | ||||
|         client = UnixControlClient(path=kwargs[SOCKET_PATH]) | ||||
|     else: | ||||
|         print("Invalid connection type", file=sys.stderr) | ||||
|         sys.exit(2) | ||||
|     await client.start() | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run(main()) | ||||
							
								
								
									
										2
									
								
								src/asyncio_taskpool/control/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								src/asyncio_taskpool/control/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| from .server import TCPControlServer, UnixControlServer | ||||
| from .client import TCPControlClient, UnixControlClient | ||||
							
								
								
									
										80
									
								
								src/asyncio_taskpool/control/__main__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								src/asyncio_taskpool/control/__main__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| CLI entry point script for a :class:`ControlClient`. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from argparse import ArgumentParser | ||||
| from asyncio import run | ||||
| from pathlib import Path | ||||
| from typing import Any, Dict, Sequence | ||||
|  | ||||
| from ..internals.constants import PACKAGE_NAME | ||||
| from ..pool import TaskPool | ||||
| from .client import TCPControlClient, UnixControlClient | ||||
| from .server import TCPControlServer, UnixControlServer | ||||
|  | ||||
|  | ||||
| __all__ = [] | ||||
|  | ||||
|  | ||||
| CLIENT_CLASS = 'client_class' | ||||
| UNIX, TCP = 'unix', 'tcp' | ||||
| SOCKET_PATH = 'path' | ||||
| HOST, PORT = 'host', 'port' | ||||
|  | ||||
|  | ||||
| def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]: | ||||
|     parser = ArgumentParser( | ||||
|         prog=f'{PACKAGE_NAME}.control', | ||||
|         description=f"Simple CLI based control client for {PACKAGE_NAME}" | ||||
|     ) | ||||
|     subparsers = parser.add_subparsers(title="Connection types") | ||||
|  | ||||
|     tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket") | ||||
|     tcp_parser.add_argument( | ||||
|         HOST, | ||||
|         help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on." | ||||
|     ) | ||||
|     tcp_parser.add_argument( | ||||
|         PORT, | ||||
|         type=int, | ||||
|         help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on." | ||||
|     ) | ||||
|     tcp_parser.set_defaults(**{CLIENT_CLASS: TCPControlClient}) | ||||
|  | ||||
|     unix_parser = subparsers.add_parser(UNIX, help="Connect via unix socket") | ||||
|     unix_parser.add_argument( | ||||
|         SOCKET_PATH, | ||||
|         type=Path, | ||||
|         help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is " | ||||
|              f"listening." | ||||
|     ) | ||||
|     unix_parser.set_defaults(**{CLIENT_CLASS: UnixControlClient}) | ||||
|  | ||||
|     return vars(parser.parse_args(args)) | ||||
|  | ||||
|  | ||||
| async def main(): | ||||
|     kwargs = parse_cli() | ||||
|     client_cls = kwargs.pop(CLIENT_CLASS) | ||||
|     await client_cls(**kwargs).start() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run(main()) | ||||
| @@ -23,17 +23,28 @@ import json | ||||
| import shutil | ||||
| import sys | ||||
| from abc import ABC, abstractmethod | ||||
| from asyncio.streams import StreamReader, StreamWriter, open_unix_connection | ||||
| from asyncio.streams import StreamReader, StreamWriter, open_connection | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
| from typing import Optional, Union | ||||
| 
 | ||||
| from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES | ||||
| from .types import ClientConnT, PathT | ||||
| from ..internals.constants import CLIENT_INFO, SESSION_MSG_BYTES | ||||
| from ..internals.types import ClientConnT, PathT | ||||
| 
 | ||||
| 
 | ||||
| __all__ = [ | ||||
|     'ControlClient', | ||||
|     'TCPControlClient', | ||||
|     'UnixControlClient', | ||||
|     'CLIENT_EXIT' | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| CLIENT_EXIT = 'exit' | ||||
| 
 | ||||
| 
 | ||||
| class ControlClient(ABC): | ||||
|     """ | ||||
|     Abstract base class for a simple implementation of a task pool control client. | ||||
|     Abstract base class for a simple implementation of a pool control client. | ||||
| 
 | ||||
|     Since the server's control interface is simply expecting commands to be sent, any process able to connect to the | ||||
|     TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well. | ||||
| @@ -41,7 +52,7 @@ class ControlClient(ABC): | ||||
|     """ | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def client_info() -> dict: | ||||
|     def _client_info() -> dict: | ||||
|         """Returns a dictionary of client information relevant for the handshake with the server.""" | ||||
|         return {CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns} | ||||
| 
 | ||||
| @@ -50,15 +61,15 @@ class ControlClient(ABC): | ||||
|         """ | ||||
|         Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair. | ||||
| 
 | ||||
|         This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked) | ||||
|         as keyword-arguments. | ||||
|         This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` | ||||
|         (unpacked) as keyword-arguments. | ||||
|         This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of | ||||
|         `None` and `None`, if it failed to establish the defined connection. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def __init__(self, **conn_kwargs) -> None: | ||||
|         """Simply stores the connection keyword-arguments necessary for opening the connection.""" | ||||
|         """Simply stores the keyword-arguments for opening the connection.""" | ||||
|         self._conn_kwargs = conn_kwargs | ||||
|         self._connected: bool = False | ||||
| 
 | ||||
| @@ -73,9 +84,10 @@ class ControlClient(ABC): | ||||
|             writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method | ||||
|         """ | ||||
|         self._connected = True | ||||
|         writer.write(json.dumps(self.client_info()).encode()) | ||||
|         writer.write(json.dumps(self._client_info()).encode()) | ||||
|         await writer.drain() | ||||
|         print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) | ||||
|         print("Type '-h' to get help and usage instructions for all available commands.\n") | ||||
| 
 | ||||
|     def _get_command(self, writer: StreamWriter) -> Optional[str]: | ||||
|         """ | ||||
| @@ -90,7 +102,7 @@ class ControlClient(ABC): | ||||
|         """ | ||||
|         try: | ||||
|             msg = input("> ").strip().lower() | ||||
|         except EOFError:  # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command. | ||||
|         except EOFError:  # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command. | ||||
|             msg = CLIENT_EXIT | ||||
|         except KeyboardInterrupt:  # Ctrl+C shall simply reset to the input prompt. | ||||
|             print() | ||||
| @@ -128,11 +140,14 @@ class ControlClient(ABC): | ||||
| 
 | ||||
|     async def start(self) -> None: | ||||
|         """ | ||||
|         This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop. | ||||
|         Opens connection, performs handshake, and enters interaction loop. | ||||
| 
 | ||||
|         An input prompt is presented to the user and any input is sent (encoded) to the connected server. | ||||
|         One exception is the :const:`CLIENT_EXIT` command (equivalent to Ctrl+D), which merely closes the connection. | ||||
| 
 | ||||
|         If the connection can not be established, an error message is printed to `stderr` and the method returns. | ||||
|         If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a | ||||
|         disconnected-message. | ||||
|         If either the exit command is issued or the connection to the server is lost during the interaction loop, | ||||
|         the method returns and prints out a disconnected-message. | ||||
|         """ | ||||
|         reader, writer = await self._open_connection(**self._conn_kwargs) | ||||
|         if reader is None: | ||||
| @@ -144,15 +159,36 @@ class ControlClient(ABC): | ||||
|         print("Disconnected from control server.") | ||||
| 
 | ||||
| 
 | ||||
| class TCPControlClient(ControlClient): | ||||
|     """Task pool control client for connecting to a :class:`TCPControlServer`.""" | ||||
| 
 | ||||
|     def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None: | ||||
|         """`host` and `port` are expected as non-optional connection arguments.""" | ||||
|         self._host = host | ||||
|         self._port = port | ||||
|         super().__init__(**conn_kwargs) | ||||
| 
 | ||||
|     async def _open_connection(self, **kwargs) -> ClientConnT: | ||||
|         """ | ||||
|         Wrapper around the `asyncio.open_connection` function. | ||||
| 
 | ||||
|         Returns a tuple of `None` and `None`, if the connection can not be established; | ||||
|         otherwise, the stream-reader and -writer tuple is returned. | ||||
|         """ | ||||
|         try: | ||||
|             return await open_connection(self._host, self._port, **kwargs) | ||||
|         except ConnectionError as e: | ||||
|             print(str(e), file=sys.stderr) | ||||
|             return None, None | ||||
| 
 | ||||
| 
 | ||||
| class UnixControlClient(ControlClient): | ||||
|     """Task pool control client that expects a unix socket to be exposed by the control server.""" | ||||
|     """Task pool control client for connecting to a :class:`UnixControlServer`.""" | ||||
| 
 | ||||
|     def __init__(self, socket_path: PathT, **conn_kwargs) -> None: | ||||
|         """ | ||||
|         In addition to what the base class does, the `socket_path` is expected as a non-optional argument. | ||||
| 
 | ||||
|         The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument. | ||||
|         """ | ||||
|         """`socket_path` is expected as a non-optional connection argument.""" | ||||
|         from asyncio.streams import open_unix_connection | ||||
|         self._open_unix_connection = open_unix_connection | ||||
|         self._socket_path = Path(socket_path) | ||||
|         super().__init__(**conn_kwargs) | ||||
| 
 | ||||
| @@ -164,7 +200,7 @@ class UnixControlClient(ControlClient): | ||||
|         otherwise, the stream-reader and -writer tuple is returned. | ||||
|         """ | ||||
|         try: | ||||
|             return await open_unix_connection(self._socket_path, **kwargs) | ||||
|             return await self._open_unix_connection(self._socket_path, **kwargs) | ||||
|         except FileNotFoundError: | ||||
|             print("No socket at", self._socket_path, file=sys.stderr) | ||||
|             return None, None | ||||
							
								
								
									
										342
									
								
								src/asyncio_taskpool/control/parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										342
									
								
								src/asyncio_taskpool/control/parser.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,342 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Definition of the :class:`ControlParser` used in a  | ||||
| :class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import logging | ||||
| from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS | ||||
| from ast import literal_eval | ||||
| from asyncio.streams import StreamWriter | ||||
| from inspect import Parameter, getmembers, isfunction, signature | ||||
| from shutil import get_terminal_size | ||||
| from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar | ||||
|  | ||||
| from ..exceptions import HelpRequested, ParserError | ||||
| from ..internals.constants import CLIENT_INFO, CMD, STREAM_WRITER | ||||
| from ..internals.helpers import get_first_doc_line, resolve_dotted_path | ||||
| from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT | ||||
|  | ||||
|  | ||||
| __all__ = ['ControlParser'] | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter]) | ||||
| ParsersDict = Dict[str, 'ControlParser'] | ||||
|  | ||||
| OMIT_PARAMS_DEFAULT = ('self', ) | ||||
|  | ||||
| NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description' | ||||
|  | ||||
|  | ||||
| class ControlParser(ArgumentParser): | ||||
|     """ | ||||
|     Subclass of the standard :code:`argparse.ArgumentParser` for pool control. | ||||
|  | ||||
|     Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` | ||||
|     instance passed to it during initialization. | ||||
|     Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a | ||||
|     connected client. | ||||
|     Finally, it offers some convenience methods and makes use of custom exceptions. | ||||
|     """ | ||||
|  | ||||
|     @staticmethod | ||||
|     def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: | ||||
|         """ | ||||
|         Constructs and returns a subclass of :class:`argparse.HelpFormatter` | ||||
|  | ||||
|         The formatter class will have the defined `terminal_width`. | ||||
|  | ||||
|         Although a custom formatter class can be explicitly passed into the :class:`ArgumentParser` constructor, | ||||
|         this is not as convenient, when making use of sub-parsers. | ||||
|  | ||||
|         Args: | ||||
|             terminal_width: | ||||
|                 The number of columns of the terminal to which to adjust help formatting. | ||||
|             base_cls (optional): | ||||
|                 Base class to use for inheritance. By default :class:`argparse.ArgumentDefaultsHelpFormatter` is used. | ||||
|  | ||||
|         Returns: | ||||
|             The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`. | ||||
|         """ | ||||
|         if base_cls is None: | ||||
|             base_cls = ArgumentDefaultsHelpFormatter | ||||
|  | ||||
|         class ClientHelpFormatter(base_cls): | ||||
|             def __init__(self, *args, **kwargs) -> None: | ||||
|                 kwargs['width'] = terminal_width | ||||
|                 super().__init__(*args, **kwargs) | ||||
|         return ClientHelpFormatter | ||||
|  | ||||
|     def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None: | ||||
|         """ | ||||
|         Sets some internal attributes in addition to the base class. | ||||
|  | ||||
|         Args: | ||||
|             stream_writer: | ||||
|                 The instance of the :class:`asyncio.StreamWriter` to use for message output. | ||||
|             terminal_width (optional): | ||||
|                 The terminal width to use for all message formatting. By default the :code:`columns` attribute from | ||||
|                 :func:`shutil.get_terminal_size` is taken. | ||||
|             **kwargs(optional): | ||||
|                 Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a | ||||
|                 class is specified, it will always be subclassed in the :meth:`help_formatter_factory`. | ||||
|                 Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). | ||||
|         """ | ||||
|         self._stream_writer: StreamWriter = stream_writer | ||||
|         self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns | ||||
|         kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class')) | ||||
|         kwargs.setdefault('exit_on_error', False) | ||||
|         super().__init__(**kwargs) | ||||
|         self._flags: Set[str] = set() | ||||
|         self._commands = None | ||||
|  | ||||
|     def add_function_command(self, function: Callable, omit_params: Container[str] = OMIT_PARAMS_DEFAULT, | ||||
|                              **subparser_kwargs) -> 'ControlParser': | ||||
|         """ | ||||
|         Takes a function and adds a corresponding (sub-)command to the parser. | ||||
|  | ||||
|         The :meth:`add_subparsers` method must have been called prior to this. | ||||
|  | ||||
|         NOTE: Currently, only a limited spectrum of parameters can be accurately converted to parser arguments. | ||||
|         This method works correctly with any public method of the any task pool class. | ||||
|  | ||||
|         Args: | ||||
|             function: | ||||
|                 The reference to the function to be "converted" to a parser command. | ||||
|             omit_params (optional): | ||||
|                 Names of function parameters not to add as parser arguments. | ||||
|             **subparser_kwargs (optional): | ||||
|                 Passed directly to the :meth:`add_parser` method. | ||||
|  | ||||
|         Returns: | ||||
|             The subparser instance created from the function. | ||||
|         """ | ||||
|         subparser_kwargs.setdefault(NAME, function.__name__.replace('_', '-')) | ||||
|         subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME]) | ||||
|         subparser_kwargs.setdefault(HELP, get_first_doc_line(function)) | ||||
|         subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP]) | ||||
|         subparser: ControlParser = self._commands.add_parser(**subparser_kwargs) | ||||
|         subparser.add_function_args(function, omit_params) | ||||
|         return subparser | ||||
|  | ||||
|     def add_property_command(self, prop: property, cls_name: str = '', **subparser_kwargs) -> 'ControlParser': | ||||
|         """ | ||||
|         Same as the :meth:`add_function_command` method, but for properties. | ||||
|  | ||||
|         Args: | ||||
|             prop: | ||||
|                 The reference to the property to be "converted" to a parser command. | ||||
|             cls_name (optional): | ||||
|                 Name of the class the property is defined on to appear in the command help text. | ||||
|             **subparser_kwargs (optional): | ||||
|                 Passed directly to the :meth:`add_parser` method. | ||||
|  | ||||
|         Returns: | ||||
|             The subparser instance created from the property. | ||||
|         """ | ||||
|         subparser_kwargs.setdefault(NAME, prop.fget.__name__.replace('_', '-')) | ||||
|         subparser_kwargs.setdefault(PROG, subparser_kwargs[NAME]) | ||||
|         getter_help = get_first_doc_line(prop.fget) | ||||
|         if prop.fset is None: | ||||
|             subparser_kwargs.setdefault(HELP, getter_help) | ||||
|         else: | ||||
|             subparser_kwargs.setdefault(HELP, f"Get/set the `{cls_name}.{subparser_kwargs[NAME]}` property") | ||||
|         subparser_kwargs.setdefault(DESCRIPTION, subparser_kwargs[HELP]) | ||||
|         subparser: ControlParser = self._commands.add_parser(**subparser_kwargs) | ||||
|         if prop.fset is not None: | ||||
|             _, param = signature(prop.fset).parameters.values() | ||||
|             setter_arg_help = f"If provided: {get_first_doc_line(prop.fset)} If omitted: {getter_help}" | ||||
|             subparser.add_function_arg(param, nargs='?', default=SUPPRESS, help=setter_arg_help) | ||||
|         return subparser | ||||
|  | ||||
|     def add_class_commands(self, cls: Type, public_only: bool = True, omit_members: Container[str] = (), | ||||
|                            member_arg_name: str = CMD) -> ParsersDict: | ||||
|         """ | ||||
|         Adds methods/properties of a class as (sub-)commands to the parser. | ||||
|  | ||||
|         The :meth:`add_subparsers` method must have been called prior to this. | ||||
|  | ||||
|         NOTE: Currently, only a limited spectrum of function parameters can be accurately converted to parser arguments. | ||||
|         This method works correctly with any task pool class. | ||||
|  | ||||
|         Args: | ||||
|             cls: | ||||
|                 The reference to the class whose methods/properties are to be "converted" to parser commands. | ||||
|             public_only (optional): | ||||
|                 If `False`, protected and private members are considered as well. `True` by default. | ||||
|             omit_members (optional): | ||||
|                 Names of functions/properties not to add as parser commands. | ||||
|             member_arg_name (optional): | ||||
|                 After parsing the arguments, depending on which command was invoked by the user, the corresponding | ||||
|                 method/property will be stored as an extra argument in the parsed namespace under this attribute name. | ||||
|  | ||||
|         Returns: | ||||
|             Dictionary mapping class member names to the (sub-)parsers created from them. | ||||
|         """ | ||||
|         parsers: ParsersDict = {} | ||||
|         common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width} | ||||
|         for name, member in getmembers(cls): | ||||
|             if name in omit_members or (name.startswith('_') and public_only): | ||||
|                 continue | ||||
|             if isfunction(member): | ||||
|                 subparser = self.add_function_command(member, **common_kwargs) | ||||
|             elif isinstance(member, property): | ||||
|                 subparser = self.add_property_command(member, cls.__name__, **common_kwargs) | ||||
|             else: | ||||
|                 continue | ||||
|             subparser.set_defaults(**{member_arg_name: member}) | ||||
|             parsers[name] = subparser | ||||
|         return parsers | ||||
|  | ||||
|     def add_subparsers(self, *args, **kwargs): | ||||
|         """Adds the subparsers action as an attribute before returning it.""" | ||||
|         self._commands = super().add_subparsers(*args, **kwargs) | ||||
|         return self._commands | ||||
|  | ||||
|     def _print_message(self, message: str, *args, **kwargs) -> None: | ||||
|         """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" | ||||
|         if message: | ||||
|             self._stream_writer.write(message.encode()) | ||||
|  | ||||
|     def exit(self, status: int = 0, message: str = None) -> None: | ||||
|         """This is overridden to prevent system exit to be invoked.""" | ||||
|         if message: | ||||
|             self._print_message(message) | ||||
|  | ||||
|     def error(self, message: str) -> None: | ||||
|         """Raises the :exc:`ParserError <asyncio_taskpool.exceptions.ParserError>` exception at the end.""" | ||||
|         super().error(message=message) | ||||
|         raise ParserError | ||||
|  | ||||
|     def print_help(self, file=None) -> None: | ||||
|         """Raises the :exc:`HelpRequested <asyncio_taskpool.exceptions.HelpRequested>` exception at the end.""" | ||||
|         super().print_help(file) | ||||
|         raise HelpRequested | ||||
|  | ||||
|     def add_function_arg(self, parameter: Parameter, **kwargs) -> Action: | ||||
|         """ | ||||
|         Takes an :class:`inspect.Parameter` and adds a corresponding parser argument. | ||||
|  | ||||
|         NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. | ||||
|         This method works correctly with any parameter of any public method any task pool class. | ||||
|  | ||||
|         Args: | ||||
|             parameter: The :class:`inspect.Parameter` object to be converted to a parser argument. | ||||
|             **kwargs: Passed to the :meth:`add_argument` method of the base class. | ||||
|  | ||||
|         Returns: | ||||
|             The :class:`argparse.Action` returned by the :meth:`add_argument` method. | ||||
|         """ | ||||
|         if parameter.default is Parameter.empty: | ||||
|             # A non-optional function parameter should correspond to a positional argument. | ||||
|             name_or_flags = [parameter.name] | ||||
|         else: | ||||
|             flag = None | ||||
|             long = f'--{parameter.name.replace("_", "-")}' | ||||
|             # We try to generate a short version (flag) for the argument. | ||||
|             letter = parameter.name[0] | ||||
|             if letter not in self._flags: | ||||
|                 flag = f'-{letter}' | ||||
|                 self._flags.add(letter) | ||||
|             elif letter.upper() not in self._flags: | ||||
|                 flag = f'-{letter.upper()}' | ||||
|                 self._flags.add(letter.upper()) | ||||
|             name_or_flags = [long] if flag is None else [flag, long] | ||||
|             if parameter.annotation is bool: | ||||
|                 # If we are dealing with a boolean parameter, always use the 'store_true' action. | ||||
|                 # Even if the parameter's default value is `True`, this will make the parser argument's default `False`. | ||||
|                 kwargs.setdefault('action', 'store_true') | ||||
|             else: | ||||
|                 # For now, any other type annotation will implicitly use the default action 'store'. | ||||
|                 # In addition, we always set the default value. | ||||
|                 kwargs.setdefault('default', parameter.default) | ||||
|         if parameter.kind == Parameter.VAR_POSITIONAL: | ||||
|             # This is to be able to later unpack an arbitrary number of positional arguments. | ||||
|             kwargs.setdefault('nargs', '*') | ||||
|         if not kwargs.get('action') == 'store_true': | ||||
|             # Set the type from the parameter annotation. | ||||
|             kwargs.setdefault('type', _get_type_from_annotation(parameter.annotation)) | ||||
|         return self.add_argument(*name_or_flags, **kwargs) | ||||
|  | ||||
|     def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None: | ||||
|         """ | ||||
|         Takes a function and adds its parameters as arguments to the parser. | ||||
|  | ||||
|         NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. | ||||
|         This method works correctly with any public method of any task pool class. | ||||
|  | ||||
|         Args: | ||||
|             function: | ||||
|                 The function whose parameters are to be converted to parser arguments. | ||||
|                 Its parameters must be properly annotated. | ||||
|             omit (optional): | ||||
|                 Names of function parameters not to add as parser arguments. | ||||
|         """ | ||||
|         for param in signature(function).parameters.values(): | ||||
|             if param.name not in omit: | ||||
|                 # TODO: Look into parsing docstrings properly to try and extract argument help text. | ||||
|                 #       For now, the argument help just shows the type it will be converted to. | ||||
|                 self.add_function_arg(param, help=repr(param.annotation)) | ||||
|  | ||||
|  | ||||
| def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]: | ||||
|     """ | ||||
|     Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments. | ||||
|  | ||||
|     See: https://bugs.python.org/issue36078 | ||||
|  | ||||
|     In addition, the type conversion wrapper catches exceptions not handled properly by the parser, logs them, and | ||||
|     turns them into `ArgumentTypeError` exceptions the parser can propagate to the client. | ||||
|     """ | ||||
|     def wrapper(arg: Any) -> Any: | ||||
|         if arg is SUPPRESS: | ||||
|             return arg | ||||
|         try: | ||||
|             return cls(arg) | ||||
|         except (ArgumentTypeError, TypeError, ValueError): | ||||
|             raise  # handled properly by the parser and propagated to the client anyway | ||||
|         except Exception as e: | ||||
|             text = f"{e.__class__.__name__} occurred in parser trying to convert type: {cls.__name__}({repr(arg)})" | ||||
|             log.exception(text) | ||||
|             raise ArgumentTypeError(text)  # propagate to the client | ||||
|     # Copy the name of the class to maintain useful help messages when incorrect arguments are passed. | ||||
|     wrapper.__name__ = cls.__name__ | ||||
|     return wrapper | ||||
|  | ||||
|  | ||||
| def _get_type_from_annotation(annotation: Type) -> Callable[[Any], Any]: | ||||
|     """ | ||||
|     Returns a type conversion function based on the `annotation` passed. | ||||
|  | ||||
|     Required to properly convert parsed arguments to the type expected by certain pool methods. | ||||
|     Each conversion function is wrapped by `_get_arg_type_wrapper`. | ||||
|  | ||||
|     `Callable`-type annotations give the `resolve_dotted_path` function. | ||||
|     `Iterable`- or args/kwargs-type annotations give the `ast.literal_eval` function. | ||||
|     Others pass unchanged (but still wrapped with `_get_arg_type_wrapper`). | ||||
|     """ | ||||
|     if any(annotation is t for t in {CoroutineFunc, EndCB, CancelCB}): | ||||
|         annotation = resolve_dotted_path | ||||
|     if any(annotation is t for t in {ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]}): | ||||
|         annotation = literal_eval | ||||
|     return _get_arg_type_wrapper(annotation) | ||||
| @@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
| 
 | ||||
| __doc__ = """ | ||||
| This module contains the task pool control server class definitions. | ||||
| Task pool control server class definitions. | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| @@ -23,35 +23,73 @@ import logging | ||||
| from abc import ABC, abstractmethod | ||||
| from asyncio import AbstractServer | ||||
| from asyncio.exceptions import CancelledError | ||||
| from asyncio.streams import StreamReader, StreamWriter, start_unix_server | ||||
| from asyncio.streams import StreamReader, StreamWriter, start_server | ||||
| from asyncio.tasks import Task, create_task | ||||
| from pathlib import Path | ||||
| from typing import Optional, Union | ||||
| 
 | ||||
| from .client import ControlClient, UnixControlClient | ||||
| from .pool import TaskPool, SimpleTaskPool | ||||
| from .client import ControlClient, TCPControlClient, UnixControlClient | ||||
| from .session import ControlSession | ||||
| from .types import ConnectedCallbackT | ||||
| from ..pool import AnyTaskPoolT | ||||
| from ..internals.types import ConnectedCallbackT, PathT | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ['ControlServer', 'TCPControlServer', 'UnixControlServer'] | ||||
| 
 | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool instances, not just SimpleTaskPool | ||||
| class ControlServer(ABC): | ||||
|     """ | ||||
|     Abstract base class for a task pool control server. | ||||
| 
 | ||||
|     This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client | ||||
|     connecting to it. The entire interface is defined within that session class. | ||||
|     This class acts as a wrapper around an async server instance and initializes a | ||||
|     :class:`ControlSession <asyncio_taskpool.control.session.ControlSession>` once a client connects to it. | ||||
|     The interface is defined within the session class. | ||||
|     """ | ||||
|     _client_class = ControlClient | ||||
| 
 | ||||
|     @classmethod | ||||
|     @property | ||||
|     def client_class_name(cls) -> str: | ||||
|         """Returns the name of the control client class matching the server class.""" | ||||
|         """Returns the name of the matching control client class.""" | ||||
|         return cls._client_class.__name__ | ||||
| 
 | ||||
|     def __init__(self, pool: AnyTaskPoolT, **server_kwargs) -> None: | ||||
|         """ | ||||
|         Merely sets internal attributes, but does not start the server yet. | ||||
|         The task pool must be passed here and can not be set/changed afterwards. This means a control server is always | ||||
|         tied to one specific task pool. | ||||
| 
 | ||||
|         Args: | ||||
|             pool: | ||||
|                 An instance of a `BaseTaskPool` subclass to tie the server to. | ||||
|             **server_kwargs (optional): | ||||
|                 Keyword arguments that will be passed into the function that starts the server. | ||||
|         """ | ||||
|         self._pool: AnyTaskPoolT = pool | ||||
|         self._server_kwargs = server_kwargs | ||||
|         self._server: Optional[AbstractServer] = None | ||||
| 
 | ||||
|     @property | ||||
|     def pool(self) -> AnyTaskPoolT: | ||||
|         """The task pool instance controlled by the server.""" | ||||
|         return self._pool | ||||
| 
 | ||||
|     def is_serving(self) -> bool: | ||||
|         """Wrapper around the `asyncio.Server.is_serving` method.""" | ||||
|         return self._server.is_serving() | ||||
| 
 | ||||
|     async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: | ||||
|         """ | ||||
|         The universal client callback that will be passed into the `_get_server_instance` method. | ||||
|         Instantiates a control session, performs the client handshake, and enters the session's `listen` loop. | ||||
|         """ | ||||
|         session = ControlSession(self, reader, writer) | ||||
|         await session.client_handshake() | ||||
|         await session.listen() | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: | ||||
|         """ | ||||
| @@ -74,40 +112,6 @@ class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool insta | ||||
|         """The method to run after the server's `serve_forever` methods ends for whatever reason.""" | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None: | ||||
|         """ | ||||
|         Initializes by merely saving the internal attributes, but without starting the server yet. | ||||
|         The task pool must be passed here and can not be set/changed afterwards. This means a control server is always | ||||
|         tied to one specific task pool. | ||||
| 
 | ||||
|         Args: | ||||
|             pool: | ||||
|                 An instance of a `BaseTaskPool` subclass to tie the server to. | ||||
|             **server_kwargs (optional): | ||||
|                 Keyword arguments that will be passed into the function that starts the server. | ||||
|         """ | ||||
|         self._pool: Union[TaskPool, SimpleTaskPool] = pool | ||||
|         self._server_kwargs = server_kwargs | ||||
|         self._server: Optional[AbstractServer] = None | ||||
| 
 | ||||
|     @property | ||||
|     def pool(self) -> Union[TaskPool, SimpleTaskPool]: | ||||
|         """Read-only property for accessing the task pool instance controlled by the server.""" | ||||
|         return self._pool | ||||
| 
 | ||||
|     def is_serving(self) -> bool: | ||||
|         """Wrapper around the `asyncio.Server.is_serving` method.""" | ||||
|         return self._server.is_serving() | ||||
| 
 | ||||
|     async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None: | ||||
|         """ | ||||
|         The universal client callback that will be passed into the `_get_server_instance` method. | ||||
|         Instantiates a control session, performs the client handshake, and enters the session's `listen` loop. | ||||
|         """ | ||||
|         session = ControlSession(self, reader, writer) | ||||
|         await session.client_handshake() | ||||
|         await session.listen() | ||||
| 
 | ||||
|     async def _serve_forever(self) -> None: | ||||
|         """ | ||||
|         To be run as an `asyncio.Task` by the following method. | ||||
| @@ -124,24 +128,50 @@ class ControlServer(ABC):  # TODO: Implement interface for normal TaskPool insta | ||||
| 
 | ||||
|     async def serve_forever(self) -> Task: | ||||
|         """ | ||||
|         This method actually starts the server and begins listening to client connections on the specified interface. | ||||
|         Starts the server and begins listening to client connections. | ||||
| 
 | ||||
|         It should never block because the serving will be performed in a separate task. | ||||
| 
 | ||||
|         Returns: | ||||
|             The forever serving task. To stop the server, this task should be cancelled. | ||||
|         """ | ||||
|         log.debug("Starting %s...", self.__class__.__name__) | ||||
|         self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs) | ||||
|         return create_task(self._serve_forever()) | ||||
| 
 | ||||
| 
 | ||||
| class UnixControlServer(ControlServer): | ||||
|     """Task pool control server class that exposes a unix socket for control clients to connect to.""" | ||||
|     _client_class = UnixControlClient | ||||
| class TCPControlServer(ControlServer): | ||||
|     """Exposes a TCP socket for control clients to connect to.""" | ||||
|     _client_class = TCPControlClient | ||||
| 
 | ||||
|     def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: | ||||
|         self._socket_path = Path(server_kwargs.pop('path')) | ||||
|     def __init__(self, pool: AnyTaskPoolT, host: str, port: Union[int, str], **server_kwargs) -> None: | ||||
|         """`host` and `port` are expected as non-optional server arguments.""" | ||||
|         self._host = host | ||||
|         self._port = port | ||||
|         super().__init__(pool, **server_kwargs) | ||||
| 
 | ||||
|     async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: | ||||
|         server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) | ||||
|         server = await start_server(client_connected_cb, self._host, self._port, **kwargs) | ||||
|         log.debug("Opened socket at %s:%s", self._host, self._port) | ||||
|         return server | ||||
| 
 | ||||
|     def _final_callback(self) -> None: | ||||
|         log.debug("Closed socket at %s:%s", self._host, self._port) | ||||
| 
 | ||||
| 
 | ||||
| class UnixControlServer(ControlServer): | ||||
|     """Exposes a unix socket for control clients to connect to.""" | ||||
|     _client_class = UnixControlClient | ||||
| 
 | ||||
|     def __init__(self, pool: AnyTaskPoolT, socket_path: PathT, **server_kwargs) -> None: | ||||
|         """`socket_path` is expected as a non-optional server argument.""" | ||||
|         from asyncio.streams import start_unix_server | ||||
|         self._start_unix_server = start_unix_server | ||||
|         self._socket_path = Path(socket_path) | ||||
|         super().__init__(pool, **server_kwargs) | ||||
| 
 | ||||
|     async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: | ||||
|         server = await self._start_unix_server(client_connected_cb, self._socket_path, **kwargs) | ||||
|         log.debug("Opened socket '%s'", str(self._socket_path)) | ||||
|         return server | ||||
| 
 | ||||
							
								
								
									
										191
									
								
								src/asyncio_taskpool/control/session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								src/asyncio_taskpool/control/session.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Definition of the :class:`ControlSession` used by a :class:`ControlServer`. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import logging | ||||
| import json | ||||
| from argparse import ArgumentError | ||||
| from asyncio.streams import StreamReader, StreamWriter | ||||
| from inspect import isfunction, signature | ||||
| from typing import Callable, Optional, Union, TYPE_CHECKING | ||||
|  | ||||
| from .parser import ControlParser | ||||
| from ..exceptions import CommandError, HelpRequested, ParserError | ||||
| from ..pool import TaskPool, SimpleTaskPool | ||||
| from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER | ||||
| from ..internals.helpers import return_or_exception | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from .server import ControlServer | ||||
|  | ||||
|  | ||||
| __all__ = ['ControlSession'] | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class ControlSession: | ||||
|     """ | ||||
|     Manages a single control session between a server and a client. | ||||
|  | ||||
|     The commands received from a connected client are translated into method calls on the task pool instance. | ||||
|     A subclass of the standard :class:`argparse.ArgumentParser` is used to handle the input read from the stream. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: | ||||
|         """ | ||||
|         Connection to the control server should already been established. | ||||
|  | ||||
|         For more convenient/efficient access, some of the server's properties are saved in separate attributes. | ||||
|         The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during | ||||
|         initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured. | ||||
|  | ||||
|         Args: | ||||
|             server: | ||||
|                 The instance of a :class:`ControlServer` subclass starting the session. | ||||
|             reader: | ||||
|                 The `asyncio.StreamReader` created when a client connected to the server. | ||||
|             writer: | ||||
|                 The `asyncio.StreamWriter` created when a client connected to the server. | ||||
|         """ | ||||
|         self._control_server: 'ControlServer' = server | ||||
|         self._pool: Union[TaskPool, SimpleTaskPool] = server.pool | ||||
|         self._client_class_name = server.client_class_name | ||||
|         self._reader: StreamReader = reader | ||||
|         self._writer: StreamWriter = writer | ||||
|         self._parser: Optional[ControlParser] = None | ||||
|  | ||||
|     async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None: | ||||
|         """ | ||||
|         Takes a pool method reference, executes it, and writes a response accordingly. | ||||
|  | ||||
|         If the first parameter is named `self`, the method will be called with the `_pool` instance as its first | ||||
|         positional argument. | ||||
|         If it returns nothing, the response upon successful execution will be :const:`constants.CMD_OK`, otherwise the | ||||
|         response written to the stream will be its return value (as an encoded string). | ||||
|  | ||||
|         Args: | ||||
|             prop: | ||||
|                 The reference to the method defined on the `_pool` instance's class. | ||||
|             **kwargs (optional): | ||||
|                 Must correspond to the arguments expected by the `method`. | ||||
|                 Correctly unpacks arbitrary-length positional and keyword-arguments. | ||||
|         """ | ||||
|         log.debug("%s calls %s.%s", self._client_class_name, self._pool.__class__.__name__, method.__name__) | ||||
|         normal_pos, var_pos = [], [] | ||||
|         for param in signature(method).parameters.values(): | ||||
|             if param.name == 'self': | ||||
|                 normal_pos.append(self._pool) | ||||
|             elif param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): | ||||
|                 normal_pos.append(kwargs.pop(param.name)) | ||||
|             elif param.kind == param.VAR_POSITIONAL: | ||||
|                 var_pos = kwargs.pop(param.name) | ||||
|         output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs) | ||||
|         self._writer.write(CMD_OK if output is None else str(output).encode()) | ||||
|  | ||||
|     async def _exec_property_and_respond(self, prop: property, **kwargs) -> None: | ||||
|         """ | ||||
|         Takes a pool property reference, executes its setter or getter, and writes a response accordingly. | ||||
|  | ||||
|         The property set/get method will always be called with the `_pool` instance as its first positional argument. | ||||
|  | ||||
|         Args: | ||||
|             prop: | ||||
|                 The reference to the property defined on the `_pool` instance's class. | ||||
|             **kwargs (optional): | ||||
|                 If not empty, the property setter is executed and the keyword arguments are passed along to it; the | ||||
|                 response upon successful execution will be :const:`constants.CMD_OK`. Otherwise the property getter is | ||||
|                 executed and the response written to the stream will be its return value (as an encoded string). | ||||
|         """ | ||||
|         if kwargs: | ||||
|             log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__) | ||||
|             await return_or_exception(prop.fset, self._pool, **kwargs) | ||||
|             self._writer.write(CMD_OK) | ||||
|         else: | ||||
|             log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__) | ||||
|             self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode()) | ||||
|  | ||||
|     async def client_handshake(self) -> None: | ||||
|         """ | ||||
|         Must be invoked before starting any other client interaction. | ||||
|  | ||||
|         Client info is retrieved, server info is sent back, and the | ||||
|         :class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up. | ||||
|         """ | ||||
|         client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) | ||||
|         log.debug("%s connected", self._client_class_name) | ||||
|         parser_kwargs = { | ||||
|             STREAM_WRITER: self._writer, | ||||
|             CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH], | ||||
|             'prog': '', | ||||
|             'usage': f'[-h] [{CMD}] ...' | ||||
|         } | ||||
|         self._parser = ControlParser(**parser_kwargs) | ||||
|         self._parser.add_subparsers(title="Commands", | ||||
|                                     metavar="(A command followed by '-h' or '--help' will show command-specific help.)") | ||||
|         self._parser.add_class_commands(self._pool.__class__) | ||||
|         self._writer.write(str(self._pool).encode()) | ||||
|         await self._writer.drain() | ||||
|  | ||||
|     async def _parse_command(self, msg: str) -> None: | ||||
|         """ | ||||
|         Takes a message from the client and attempts to parse it. | ||||
|  | ||||
|         If a parsing error occurs, it is returned to the client. If the :exc:`HelpRequested` exception was raised by the | ||||
|         :class:`ControlParser`, nothing else happens. Otherwise, the appropriate `_exec...` method is called with the | ||||
|         entire dictionary of keyword-arguments returned by the :class:`ControlParser` passed into it. | ||||
|  | ||||
|         Args: | ||||
|             msg: The non-empty string read from the client stream. | ||||
|         """ | ||||
|         try: | ||||
|             kwargs = vars(self._parser.parse_args(msg.split(' '))) | ||||
|         except ArgumentError as e: | ||||
|             log.debug("%s got an ArgumentError", self._client_class_name) | ||||
|             self._writer.write(str(e).encode()) | ||||
|             return | ||||
|         except (HelpRequested, ParserError): | ||||
|             log.debug("%s received usage help", self._client_class_name) | ||||
|             return | ||||
|         command = kwargs.pop(CMD) | ||||
|         if isfunction(command): | ||||
|             await self._exec_method_and_respond(command, **kwargs) | ||||
|         elif isinstance(command, property): | ||||
|             await self._exec_property_and_respond(command, **kwargs) | ||||
|         else: | ||||
|             self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode()) | ||||
|  | ||||
|     async def listen(self) -> None: | ||||
|         """ | ||||
|         Enters the main control loop listening to client input. | ||||
|  | ||||
|         This method only returns if either the server or the client disconnect. | ||||
|         Messages from the client are read, parsed, and turned into pool commands (if possible). | ||||
|         This method should be called, when the client connection was established and the handshake was successful. | ||||
|         It will obviously block indefinitely. | ||||
|         """ | ||||
|         while self._control_server.is_serving(): | ||||
|             msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() | ||||
|             if not msg: | ||||
|                 log.debug("%s disconnected", self._client_class_name) | ||||
|                 break | ||||
|             await self._parse_command(msg) | ||||
|             await self._writer.drain() | ||||
| @@ -23,6 +23,10 @@ class PoolException(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class PoolIsClosed(PoolException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class PoolIsLocked(PoolException): | ||||
|     pass | ||||
|  | ||||
| @@ -43,7 +47,7 @@ class InvalidTaskID(PoolException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class PoolStillUnlocked(PoolException): | ||||
| class InvalidGroupName(PoolException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| @@ -55,13 +59,13 @@ class ServerException(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class UnknownTaskPoolClass(ServerException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class NotATaskPool(ServerException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class HelpRequested(ServerException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class ParserError(ServerException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class CommandError(ServerException): | ||||
|     pass | ||||
|   | ||||
| @@ -1,69 +0,0 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Miscellaneous helper functions. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from asyncio.coroutines import iscoroutinefunction | ||||
| from asyncio.queues import Queue | ||||
| from inspect import getdoc | ||||
| from typing import Any, Optional, Union | ||||
|  | ||||
| from .types import T, AnyCallableT, ArgsT, KwArgsT | ||||
|  | ||||
|  | ||||
| async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]: | ||||
|     if not callable(function): | ||||
|         return | ||||
|     if kwargs is None: | ||||
|         kwargs = {} | ||||
|     if iscoroutinefunction(function): | ||||
|         return await function(*args, **kwargs) | ||||
|     return function(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T: | ||||
|     if arg_stars == 0: | ||||
|         return function(arg) | ||||
|     if arg_stars == 1: | ||||
|         return function(*arg) | ||||
|     if arg_stars == 2: | ||||
|         return function(**arg) | ||||
|     raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.") | ||||
|  | ||||
|  | ||||
| async def join_queue(q: Queue) -> None: | ||||
|     await q.join() | ||||
|  | ||||
|  | ||||
| def tasks_str(num: int) -> str: | ||||
|     return "tasks" if num != 1 else "task" | ||||
|  | ||||
|  | ||||
| def get_first_doc_line(obj: object) -> str: | ||||
|     return getdoc(obj).strip().split("\n", 1)[0].strip() | ||||
|  | ||||
|  | ||||
| async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]: | ||||
|     try: | ||||
|         if iscoroutinefunction(_function_to_execute): | ||||
|             return await _function_to_execute(*args, **kwargs) | ||||
|         else: | ||||
|             return _function_to_execute(*args, **kwargs) | ||||
|     except Exception as e: | ||||
|         return e | ||||
							
								
								
									
										0
									
								
								src/asyncio_taskpool/internals/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/asyncio_taskpool/internals/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -16,29 +16,22 @@ If not, see <https://www.gnu.org/licenses/>.""" | ||||
| 
 | ||||
| __doc__ = """ | ||||
| Constants used by more than one module in the package. | ||||
| 
 | ||||
| This module should **not** be considered part of the public API. | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| PACKAGE_NAME = 'asyncio_taskpool' | ||||
| 
 | ||||
| CLIENT_EXIT = 'exit' | ||||
| DEFAULT_TASK_GROUP = 'default' | ||||
| 
 | ||||
| SESSION_MSG_BYTES = 1024 * 100 | ||||
| SESSION_WRITER = 'session_writer' | ||||
| 
 | ||||
| STREAM_WRITER = 'stream_writer' | ||||
| CMD = 'command' | ||||
| CMD_OK = b"ok" | ||||
| 
 | ||||
| 
 | ||||
| class CLIENT_INFO: | ||||
|     __slots__ = () | ||||
|     TERMINAL_WIDTH = 'terminal_width' | ||||
| 
 | ||||
| 
 | ||||
| class CMD: | ||||
|     __slots__ = () | ||||
|     CMD = 'command' | ||||
|     NAME = 'name' | ||||
|     POOL_SIZE = 'pool-size' | ||||
|     NUM_RUNNING = 'num-running' | ||||
|     START = 'start' | ||||
|     STOP = 'stop' | ||||
|     STOP_ALL = 'stop-all' | ||||
|     FUNC_NAME = 'func-name' | ||||
							
								
								
									
										77
									
								
								src/asyncio_taskpool/internals/group_register.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								src/asyncio_taskpool/internals/group_register.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,77 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Definition of :class:`TaskGroupRegister`. | ||||
|  | ||||
| It should not be considered part of the public API. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from asyncio.locks import Lock | ||||
| from collections.abc import MutableSet | ||||
| from typing import Iterator, Set | ||||
|  | ||||
|  | ||||
| class TaskGroupRegister(MutableSet): | ||||
|     """ | ||||
|     Combines the interface of a regular `set` with that of the `asyncio.Lock`. | ||||
|  | ||||
|     Serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for | ||||
|     preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of | ||||
|     tasks, as well as before starting a task within the group. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, *task_ids: int) -> None: | ||||
|         self._ids: Set[int] = set(task_ids) | ||||
|         self._lock = Lock() | ||||
|  | ||||
|     def __contains__(self, task_id: int) -> bool: | ||||
|         """Abstract method for the `MutableSet` base class.""" | ||||
|         return task_id in self._ids | ||||
|  | ||||
|     def __iter__(self) -> Iterator[int]: | ||||
|         """Abstract method for the `MutableSet` base class.""" | ||||
|         return iter(self._ids) | ||||
|  | ||||
|     def __len__(self) -> int: | ||||
|         """Abstract method for the `MutableSet` base class.""" | ||||
|         return len(self._ids) | ||||
|  | ||||
|     def add(self, task_id: int) -> None: | ||||
|         """Abstract method for the `MutableSet` base class.""" | ||||
|         self._ids.add(task_id) | ||||
|  | ||||
|     def discard(self, task_id: int) -> None: | ||||
|         """Abstract method for the `MutableSet` base class.""" | ||||
|         self._ids.discard(task_id) | ||||
|  | ||||
|     async def acquire(self) -> bool: | ||||
|         """Wrapper around the lock's `acquire()` method.""" | ||||
|         return await self._lock.acquire() | ||||
|  | ||||
|     def release(self) -> None: | ||||
|         """Wrapper around the lock's `release()` method.""" | ||||
|         self._lock.release() | ||||
|  | ||||
|     async def __aenter__(self) -> None: | ||||
|         """Provides the asynchronous context manager syntax `async with ... :` when using the lock.""" | ||||
|         await self._lock.acquire() | ||||
|         return None | ||||
|  | ||||
|     async def __aexit__(self, exc_type, exc, tb) -> None: | ||||
|         """Provides the asynchronous context manager syntax `async with ... :` when using the lock.""" | ||||
|         self._lock.release() | ||||
							
								
								
									
										133
									
								
								src/asyncio_taskpool/internals/helpers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								src/asyncio_taskpool/internals/helpers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,133 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Miscellaneous helper functions. None of these should be considered part of the public API. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from asyncio.coroutines import iscoroutinefunction | ||||
| from importlib import import_module | ||||
| from inspect import getdoc | ||||
| from typing import Any, Optional, Union | ||||
|  | ||||
| from .types import T, AnyCallableT, ArgsT, KwArgsT | ||||
|  | ||||
|  | ||||
| async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]: | ||||
|     """ | ||||
|     Runs `function` with `args` and `kwargs` and returns its output. | ||||
|  | ||||
|     Args: | ||||
|         function: | ||||
|             Any callable that accepts the provided positional and keyword-arguments. | ||||
|             If it is a coroutine function, it will be awaited. | ||||
|             If it is not a callable, nothing is returned. | ||||
|         *args (optional): | ||||
|             Positional arguments to pass to `function`. | ||||
|         **kwargs (optional): | ||||
|             Keyword-arguments to pass to `function`. | ||||
|  | ||||
|     Returns: | ||||
|         Whatever `function` returns (possibly after being awaited) or `None` if `function` is not callable. | ||||
|     """ | ||||
|     if not callable(function): | ||||
|         return | ||||
|     if kwargs is None: | ||||
|         kwargs = {} | ||||
|     if iscoroutinefunction(function): | ||||
|         return await function(*args, **kwargs) | ||||
|     return function(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T: | ||||
|     """ | ||||
|     Calls `function` passing `arg` to it, optionally unpacking it first. | ||||
|  | ||||
|     Args: | ||||
|         function: | ||||
|             Any callable that accepts the provided argument(s). | ||||
|         arg: | ||||
|             The single positional argument that `function` expects; in this case `arg_stars` should be 0. | ||||
|             Or the iterable of positional arguments that `function` expects; in this case `arg_stars` should be 1. | ||||
|             Or the mapping of keyword-arguments that `function` expects; in this case `arg_stars` should be 2. | ||||
|         arg_stars (optional): | ||||
|             Determines if and how to unpack `arg`. | ||||
|             0 means no unpacking, i.e. `arg` is passed into `function` directly as `function(arg)`. | ||||
|             1 means unpacking to an arbitrary number of positional arguments, i.e. as `function(*arg)`. | ||||
|             2 means unpacking to an arbitrary number of keyword-arguments, i.e. as `function(**arg)`. | ||||
|  | ||||
|     Returns: | ||||
|         Whatever `function` returns. | ||||
|  | ||||
|     Raises: | ||||
|         `ValueError`: `arg_stars` is something other than 0, 1, or 2. | ||||
|     """ | ||||
|     if arg_stars == 0: | ||||
|         return function(arg) | ||||
|     if arg_stars == 1: | ||||
|         return function(*arg) | ||||
|     if arg_stars == 2: | ||||
|         return function(**arg) | ||||
|     raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.") | ||||
|  | ||||
|  | ||||
| def get_first_doc_line(obj: object) -> str: | ||||
|     """Takes an object and returns the first (non-empty) line of its docstring.""" | ||||
|     return getdoc(obj).strip().split("\n", 1)[0].strip() | ||||
|  | ||||
|  | ||||
| async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]: | ||||
|     """ | ||||
|     Returns the output of a function or the exception thrown during its execution. | ||||
|  | ||||
|     Args: | ||||
|         _function_to_execute: | ||||
|             Any callable that accepts the provided positional and keyword-arguments. | ||||
|         *args (optional): | ||||
|             Positional arguments to pass to `_function_to_execute`. | ||||
|         **kwargs (optional): | ||||
|             Keyword-arguments to pass to `_function_to_execute`. | ||||
|  | ||||
|     Returns: | ||||
|         Whatever `_function_to_execute` returns or throws. (An exception is not raised, but returned!) | ||||
|     """ | ||||
|     try: | ||||
|         if iscoroutinefunction(_function_to_execute): | ||||
|             return await _function_to_execute(*args, **kwargs) | ||||
|         else: | ||||
|             return _function_to_execute(*args, **kwargs) | ||||
|     except Exception as e: | ||||
|         return e | ||||
|  | ||||
|  | ||||
| def resolve_dotted_path(dotted_path: str) -> object: | ||||
|     """ | ||||
|     Resolves a dotted path to a global object and returns that object. | ||||
|  | ||||
|     Algorithm shamelessly stolen from the `logging.config` module from the standard library. | ||||
|     """ | ||||
|     names = dotted_path.split('.') | ||||
|     module_name = names.pop(0) | ||||
|     found = import_module(module_name) | ||||
|     for name in names: | ||||
|         try: | ||||
|             found = getattr(found, name) | ||||
|         except AttributeError: | ||||
|             module_name += f'.{name}' | ||||
|             import_module(module_name) | ||||
|             found = getattr(found, name) | ||||
|     return found | ||||
| @@ -16,6 +16,8 @@ If not, see <https://www.gnu.org/licenses/>.""" | ||||
| 
 | ||||
| __doc__ = """ | ||||
| Custom type definitions used in various modules. | ||||
| 
 | ||||
| This module should **not** be considered part of the public API. | ||||
| """ | ||||
| 
 | ||||
| 
 | ||||
| @@ -32,8 +34,8 @@ KwArgsT = Mapping[str, Any] | ||||
| AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] | ||||
| CoroutineFunc = Callable[[...], Awaitable[Any]] | ||||
| 
 | ||||
| EndCallbackT = Callable | ||||
| CancelCallbackT = Callable | ||||
| EndCB = Callable | ||||
| CancelCB = Callable | ||||
| 
 | ||||
| ConnectedCallbackT = Callable[[StreamReader, StreamWriter], Awaitable[None]] | ||||
| ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]] | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										66
									
								
								src/asyncio_taskpool/queue_context.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								src/asyncio_taskpool/queue_context.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Definition of an :code:`asyncio.Queue` subclass with some small additions. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from asyncio.queues import Queue as _Queue | ||||
| from typing import Any | ||||
|  | ||||
|  | ||||
| __all__ = ['Queue'] | ||||
|  | ||||
|  | ||||
| class Queue(_Queue): | ||||
|     """ | ||||
|     Adds a little syntactic sugar to the :code:`asyncio.Queue`. | ||||
|  | ||||
|     Allows being used as an async context manager awaiting `get` upon entering the context and calling | ||||
|     :meth:`item_processed` upon exiting it. | ||||
|     """ | ||||
|  | ||||
|     def item_processed(self) -> None: | ||||
|         """ | ||||
|         Does exactly the same as :meth:`asyncio.Queue.task_done`. | ||||
|  | ||||
|         This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing, | ||||
|         invites confusion, and immensely reduces readability (in the context of this library). And readability counts. | ||||
|         """ | ||||
|         self.task_done() | ||||
|  | ||||
|     async def __aenter__(self) -> Any: | ||||
|         """ | ||||
|         Implements an asynchronous context manager for the queue. | ||||
|  | ||||
|         Upon entering :meth:`get` is awaited and subsequently whatever came out of the queue is returned. | ||||
|         It allows writing code this way: | ||||
|         >>> queue = Queue() | ||||
|         >>> ... | ||||
|         >>> async with queue as item: | ||||
|         >>>     ... | ||||
|         """ | ||||
|         return await self.get() | ||||
|  | ||||
|     async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: | ||||
|         """ | ||||
|         Implements an asynchronous context manager for the queue. | ||||
|  | ||||
|         Upon exiting :meth:`item_processed` is called. This is why this context manager may not always be what you want, | ||||
|         but in some situations it makes the code much cleaner. | ||||
|         """ | ||||
|         self.item_processed() | ||||
| @@ -1,304 +0,0 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| This module contains the the definition of the control session class used by the control server. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import logging | ||||
| import json | ||||
| from argparse import ArgumentError, HelpFormatter | ||||
| from asyncio.streams import StreamReader, StreamWriter | ||||
| from typing import Callable, Optional, Union, TYPE_CHECKING | ||||
|  | ||||
| from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO | ||||
| from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass | ||||
| from .helpers import get_first_doc_line, return_or_exception, tasks_str | ||||
| from .pool import BaseTaskPool, TaskPool, SimpleTaskPool | ||||
| from .session_parser import CommandParser, NUM | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from .server import ControlServer | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class ControlSession: | ||||
|     """ | ||||
|     This class defines the API for controlling a task pool instance from the outside. | ||||
|  | ||||
|     The commands received from a connected client are translated into method calls on the task pool instance. | ||||
|     A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: | ||||
|         """ | ||||
|         Instantiation should happen once a client connection to the control server has already been established. | ||||
|  | ||||
|         For more convenient/efficient access, some of the server's properties are saved in separate attributes. | ||||
|         The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during | ||||
|         initialization, which is obtained in the `client_handshake` method; only there is the parser fully configured. | ||||
|  | ||||
|         Args: | ||||
|             server: | ||||
|                 The instance of a `ControlServer` subclass starting the session. | ||||
|             reader: | ||||
|                 The `asyncio.StreamReader` created when a client connected to the server. | ||||
|             writer: | ||||
|                 The `asyncio.StreamWriter` created when a client connected to the server. | ||||
|         """ | ||||
|         self._control_server: 'ControlServer' = server | ||||
|         self._pool: Union[TaskPool, SimpleTaskPool] = server.pool | ||||
|         self._client_class_name = server.client_class_name | ||||
|         self._reader: StreamReader = reader | ||||
|         self._writer: StreamWriter = writer | ||||
|         self._parser: Optional[CommandParser] = None | ||||
|         self._subparsers = None | ||||
|  | ||||
|     def _add_command(self, name: str, prog: str = None, short_help: str = None, long_help: str = None, | ||||
|                      **kwargs) -> CommandParser: | ||||
|         """ | ||||
|         Convenience method for adding a subparser (i.e. another command) to the main `CommandParser` instance. | ||||
|  | ||||
|         Will always pass the session's main `CommandParser` instance as the `parent` keyword-argument. | ||||
|  | ||||
|         Args: | ||||
|             name: | ||||
|                 The command name; passed directly into the `add_parser` method. | ||||
|             prog (optional): | ||||
|                 Also passed into the `add_parser` method as the corresponding keyword-argument. By default, is set | ||||
|                 equal to the `name` argument. | ||||
|             short_help (optional): | ||||
|                 Passed into the `add_parser` method as the `help` keyword-argument, unless it is left empty and the | ||||
|                 `long_help` argument is present; in that case the `long_help` argument is passed as `help`. | ||||
|             long_help (optional): | ||||
|                 Passed into the `add_parser` method as the `description` keyword-argument, unless it is left empty and | ||||
|                 the `short_help` argument is present; in that case the `short_help` argument is passed as `description`. | ||||
|             **kwargs (optional): | ||||
|                 Any keyword-arguments to directly pass into the `add_parser` method. | ||||
|  | ||||
|         Returns: | ||||
|             An instance of the `CommandParser` class representing the newly added control command. | ||||
|         """ | ||||
|         if prog is None: | ||||
|             prog = name | ||||
|         kwargs.setdefault('help', short_help or long_help) | ||||
|         kwargs.setdefault('description', long_help or short_help) | ||||
|         return self._subparsers.add_parser(name, prog=prog, parent=self._parser, **kwargs) | ||||
|  | ||||
|     def _add_base_commands(self) -> None: | ||||
|         """ | ||||
|         Adds the commands that are supported regardless of the specific subclass of `BaseTaskPool` controlled. | ||||
|  | ||||
|         These include commands mapping to the following pool methods: | ||||
|             - __str__ | ||||
|             - pool_size (get/set property) | ||||
|             - num_running | ||||
|         """ | ||||
|         self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__)) | ||||
|         self._add_command( | ||||
|             CMD.POOL_SIZE,  | ||||
|             short_help="Get/set the maximum number of tasks in the pool.",  | ||||
|             formatter_class=HelpFormatter | ||||
|         ).add_optional_num_argument( | ||||
|             default=None, | ||||
|             help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} " | ||||
|                  f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}" | ||||
|         ) | ||||
|         self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget)) | ||||
|  | ||||
|     def _add_simple_commands(self) -> None: | ||||
|         """ | ||||
|         Adds the commands that are only supported, if a `SimpleTaskPool` object is controlled. | ||||
|  | ||||
|         These include commands mapping to the following pool methods: | ||||
|             - start | ||||
|             - stop | ||||
|             - stop_all | ||||
|             - func_name | ||||
|         """ | ||||
|         self._add_command( | ||||
|             CMD.START, short_help=get_first_doc_line(self._pool.__class__.start) | ||||
|         ).add_optional_num_argument( | ||||
|             help="Number of tasks to start." | ||||
|         ) | ||||
|         self._add_command( | ||||
|             CMD.STOP, short_help=get_first_doc_line(self._pool.__class__.stop) | ||||
|         ).add_optional_num_argument( | ||||
|             help="Number of tasks to stop." | ||||
|         ) | ||||
|         self._add_command(CMD.STOP_ALL, short_help=get_first_doc_line(self._pool.__class__.stop_all)) | ||||
|         self._add_command(CMD.FUNC_NAME, short_help=get_first_doc_line(self._pool.__class__.func_name.fget)) | ||||
|  | ||||
|     def _add_advanced_commands(self) -> None: | ||||
|         """ | ||||
|         Adds the commands that are only supported, if a `TaskPool` object is controlled. | ||||
|  | ||||
|         These include commands mapping to the following pool methods: | ||||
|             - ... | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _init_parser(self, client_terminal_width: int) -> None: | ||||
|         """ | ||||
|         Initializes and fully configures the `CommandParser` responsible for handling the input. | ||||
|  | ||||
|         Depending on what specific task pool class is controlled by the server, different commands are added. | ||||
|  | ||||
|         Args: | ||||
|             client_terminal_width: | ||||
|                 The number of columns of the client's terminal to be able to nicely format messages from the parser. | ||||
|         """ | ||||
|         parser_kwargs = { | ||||
|             'prog': '', | ||||
|             SESSION_WRITER: self._writer, | ||||
|             CLIENT_INFO.TERMINAL_WIDTH: client_terminal_width, | ||||
|         } | ||||
|         self._parser = CommandParser(**parser_kwargs) | ||||
|         self._subparsers = self._parser.add_subparsers(title="Commands", dest=CMD.CMD) | ||||
|         self._add_base_commands() | ||||
|         if isinstance(self._pool, TaskPool): | ||||
|             self._add_advanced_commands() | ||||
|         elif isinstance(self._pool, SimpleTaskPool): | ||||
|             self._add_simple_commands() | ||||
|         elif isinstance(self._pool, BaseTaskPool): | ||||
|             raise UnknownTaskPoolClass(f"No interface defined for {self._pool.__class__.__name__}") | ||||
|         else: | ||||
|             raise NotATaskPool(f"Not a task pool instance: {self._pool}") | ||||
|  | ||||
|     async def client_handshake(self) -> None: | ||||
|         """ | ||||
|         This method must be invoked before starting any other client interaction. | ||||
|  | ||||
|         Client info is retrieved, server info is sent back, and the `CommandParser` is initialized and configured. | ||||
|         """ | ||||
|         client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) | ||||
|         log.debug("%s connected", self._client_class_name) | ||||
|         self._init_parser(client_info[CLIENT_INFO.TERMINAL_WIDTH]) | ||||
|         self._writer.write(str(self._pool).encode()) | ||||
|         await self._writer.drain() | ||||
|  | ||||
|     async def _write_function_output(self, func: Callable, *args, **kwargs) -> None: | ||||
|         """ | ||||
|         Acts as a wrapper around a call to a specific task pool method. | ||||
|  | ||||
|         The method is called and any exception is caught and saved. If there is no output and no exception caught, a | ||||
|         generic confirmation message is sent back to the client. Otherwise the output or a string representation of | ||||
|         the exception caught is sent back. | ||||
|  | ||||
|         Args: | ||||
|             func: | ||||
|                 Reference to the task pool method. | ||||
|             *args (optional): | ||||
|                 Any positional arguments to call the method with. | ||||
|             *+kwargs (optional): | ||||
|                 Any keyword-arguments to call the method with. | ||||
|         """ | ||||
|         output = await return_or_exception(func, *args, **kwargs) | ||||
|         self._writer.write(b"ok" if output is None else str(output).encode()) | ||||
|  | ||||
|     async def _cmd_name(self, **_kwargs) -> None: | ||||
|         """Maps to the `__str__` method of any task pool class.""" | ||||
|         log.debug("%s requests task pool name", self._client_class_name) | ||||
|         await self._write_function_output(self._pool.__class__.__str__, self._pool) | ||||
|  | ||||
|     async def _cmd_pool_size(self, **kwargs) -> None: | ||||
|         """Maps to the `pool_size` property of any task pool class.""" | ||||
|         num = kwargs.get(NUM) | ||||
|         if num is None: | ||||
|             log.debug("%s requests pool size", self._client_class_name) | ||||
|             await self._write_function_output(self._pool.__class__.pool_size.fget, self._pool) | ||||
|         else: | ||||
|             log.debug("%s requests setting pool size to %s", self._client_class_name, num) | ||||
|             await self._write_function_output(self._pool.__class__.pool_size.fset, self._pool, num) | ||||
|  | ||||
|     async def _cmd_num_running(self, **_kwargs) -> None: | ||||
|         """Maps to the `num_running` property of any task pool class.""" | ||||
|         log.debug("%s requests number of running tasks", self._client_class_name) | ||||
|         await self._write_function_output(self._pool.__class__.num_running.fget, self._pool) | ||||
|  | ||||
|     async def _cmd_start(self, **kwargs) -> None: | ||||
|         """Maps to the `start` method of the `SimpleTaskPool` class.""" | ||||
|         num = kwargs[NUM] | ||||
|         log.debug("%s requests starting %s %s", self._client_class_name, num, tasks_str(num)) | ||||
|         await self._write_function_output(self._pool.start, num) | ||||
|  | ||||
|     async def _cmd_stop(self, **kwargs) -> None: | ||||
|         """Maps to the `stop` method of the `SimpleTaskPool` class.""" | ||||
|         num = kwargs[NUM] | ||||
|         log.debug("%s requests stopping %s %s", self._client_class_name, num, tasks_str(num)) | ||||
|         await self._write_function_output(self._pool.stop, num) | ||||
|  | ||||
|     async def _cmd_stop_all(self, **_kwargs) -> None: | ||||
|         """Maps to the `stop_all` method of the `SimpleTaskPool` class.""" | ||||
|         log.debug("%s requests stopping all tasks", self._client_class_name) | ||||
|         await self._write_function_output(self._pool.stop_all) | ||||
|  | ||||
|     async def _cmd_func_name(self, **_kwargs) -> None: | ||||
|         """Maps to the `func_name` method of the `SimpleTaskPool` class.""" | ||||
|         log.debug("%s requests pool function name", self._client_class_name) | ||||
|         await self._write_function_output(self._pool.__class__.func_name.fget, self._pool) | ||||
|  | ||||
|     async def _execute_command(self, **kwargs) -> None: | ||||
|         """ | ||||
|         Dynamically gets the correct `_cmd_...` method depending on the name of the command passed and executes it. | ||||
|  | ||||
|         Args: | ||||
|             **kwargs: | ||||
|                 Must include the `CMD.CMD` key mapping the the command name. The rest of the keyword-arguments is | ||||
|                 simply passed into the method determined from the command name. | ||||
|         """ | ||||
|         method = getattr(self, f'_cmd_{kwargs.pop(CMD.CMD).replace("-", "_")}') | ||||
|         await method(**kwargs) | ||||
|  | ||||
|     async def _parse_command(self, msg: str) -> None: | ||||
|         """ | ||||
|         Takes a message from the client and attempts to parse it. | ||||
|  | ||||
|         If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the | ||||
|         `CommandParser`, nothing else happens. Otherwise, the `_execute_command` method is called with the entire | ||||
|         dictionary of keyword-arguments returned by the `CommandParser` passed into it. | ||||
|  | ||||
|         Args: | ||||
|             msg: | ||||
|                 The non-empty string read from the client stream. | ||||
|         """ | ||||
|         try: | ||||
|             kwargs = vars(self._parser.parse_args(msg.split(' '))) | ||||
|         except ArgumentError as e: | ||||
|             self._writer.write(str(e).encode()) | ||||
|             return | ||||
|         except HelpRequested: | ||||
|             return | ||||
|         await self._execute_command(**kwargs) | ||||
|  | ||||
|     async def listen(self) -> None: | ||||
|         """ | ||||
|         Enters the main control loop that only ends if either the server or the client disconnect. | ||||
|  | ||||
|         Messages from the client are read and passed into the `_parse_command` method, which handles the rest. | ||||
|         This method should be called, when the client connection was established and the handshake was successful. | ||||
|         It will obviously block indefinitely. | ||||
|         """ | ||||
|         while self._control_server.is_serving(): | ||||
|             msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() | ||||
|             if not msg: | ||||
|                 log.debug("%s disconnected", self._client_class_name) | ||||
|                 break | ||||
|             await self._parse_command(msg) | ||||
|             await self._writer.drain() | ||||
| @@ -1,127 +0,0 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| This module contains the the definition of the `CommandParser` class used in a control server session. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter | ||||
| from asyncio.streams import StreamWriter | ||||
| from typing import Type, TypeVar | ||||
|  | ||||
| from .constants import SESSION_WRITER, CLIENT_INFO | ||||
| from .exceptions import HelpRequested | ||||
|  | ||||
|  | ||||
| FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter]) | ||||
| FORMATTER_CLASS = 'formatter_class' | ||||
| NUM = 'num' | ||||
|  | ||||
|  | ||||
| class CommandParser(ArgumentParser): | ||||
|     """ | ||||
|     Subclass of the standard `argparse.ArgumentParser` for remote interaction. | ||||
|  | ||||
|     Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` | ||||
|     instance passed to it during initialization. | ||||
|     Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a | ||||
|     connected client. | ||||
|     Finally, it offers some convenience methods and makes use of custom exceptions. | ||||
|     """ | ||||
|  | ||||
|     @staticmethod | ||||
|     def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: | ||||
|         """ | ||||
|         Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument. | ||||
|  | ||||
|         Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not | ||||
|         as convenient, when making use of sub-parsers. | ||||
|  | ||||
|         Args: | ||||
|             terminal_width: | ||||
|                 The number of columns of the terminal to which to adjust help formatting. | ||||
|             base_cls (optional): | ||||
|                 The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used. | ||||
|  | ||||
|         Returns: | ||||
|             The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`. | ||||
|         """ | ||||
|         if base_cls is None: | ||||
|             base_cls = ArgumentDefaultsHelpFormatter | ||||
|  | ||||
|         class ClientHelpFormatter(base_cls): | ||||
|             def __init__(self, *args, **kwargs) -> None: | ||||
|                 kwargs['width'] = terminal_width | ||||
|                 super().__init__(*args, **kwargs) | ||||
|         return ClientHelpFormatter | ||||
|  | ||||
|     def __init__(self, parent: 'CommandParser' = None, **kwargs) -> None: | ||||
|         """ | ||||
|         Sets additional internal attributes depending on whether a parent-parser was defined. | ||||
|  | ||||
|         The `help_formatter_factory` is called and the returned class is mapped to the `FORMATTER_CLASS` keyword. | ||||
|         By default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it). | ||||
|  | ||||
|         Args: | ||||
|             parent (optional): | ||||
|                 An instance of the same class. Intended to be passed as a keyword-argument into the `add_parser` method | ||||
|                 of the subparsers action returned by the `ArgumentParser.add_subparsers` method. If this is present, | ||||
|                 the `SESSION_WRITER` and `CLIENT_INFO.TERMINAL_WIDTH` keywords must not be present in `kwargs`. | ||||
|             **kwargs(optional): | ||||
|                 In addition to the regular `ArgumentParser` constructor parameters, this method expects the instance of | ||||
|                 the `StreamWriter` as well as the terminal width both to be passed explicitly, if the `parent` argument | ||||
|                 is empty. | ||||
|         """ | ||||
|         self._session_writer: StreamWriter = parent.session_writer if parent else kwargs.pop(SESSION_WRITER) | ||||
|         self._terminal_width: int = parent.terminal_width if parent else kwargs.pop(CLIENT_INFO.TERMINAL_WIDTH) | ||||
|         kwargs[FORMATTER_CLASS] = self.help_formatter_factory(self._terminal_width, kwargs.get(FORMATTER_CLASS)) | ||||
|         kwargs.setdefault('exit_on_error', False) | ||||
|         super().__init__(**kwargs) | ||||
|  | ||||
|     @property | ||||
|     def session_writer(self) -> StreamWriter: | ||||
|         """Returns the predefined stream writer object of the control session.""" | ||||
|         return self._session_writer | ||||
|  | ||||
|     @property | ||||
|     def terminal_width(self) -> int: | ||||
|         """Returns the predefined terminal width.""" | ||||
|         return self._terminal_width | ||||
|  | ||||
|     def _print_message(self, message: str, *args, **kwargs) -> None: | ||||
|         """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" | ||||
|         if message: | ||||
|             self._session_writer.write(message.encode()) | ||||
|  | ||||
|     def exit(self, status: int = 0, message: str = None) -> None: | ||||
|         """This is overridden to prevent system exit to be invoked.""" | ||||
|         if message: | ||||
|             self._print_message(message) | ||||
|  | ||||
|     def print_help(self, file=None) -> None: | ||||
|         """This just adds the custom `HelpRequested` exception after the parent class' method.""" | ||||
|         super().print_help(file) | ||||
|         raise HelpRequested | ||||
|  | ||||
|     def add_optional_num_argument(self, *name_or_flags: str, **kwargs) -> Action: | ||||
|         """Convenience method for `add_argument` setting the name, `nargs`, `default`, and `type`, unless specified.""" | ||||
|         if not name_or_flags: | ||||
|             name_or_flags = (NUM, ) | ||||
|         kwargs.setdefault('nargs', '?') | ||||
|         kwargs.setdefault('default', 1) | ||||
|         kwargs.setdefault('type', int) | ||||
|         return self.add_argument(*name_or_flags, **kwargs) | ||||
							
								
								
									
										0
									
								
								tests/test_control/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/test_control/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										45
									
								
								tests/test_control/test___main__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								tests/test_control/test___main__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| from pathlib import Path | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import AsyncMock, MagicMock, patch | ||||
|  | ||||
| from asyncio_taskpool.control.client import TCPControlClient, UnixControlClient | ||||
| from asyncio_taskpool.control import __main__ as module | ||||
|  | ||||
|  | ||||
| class CLITestCase(IsolatedAsyncioTestCase): | ||||
|  | ||||
|     def test_parse_cli(self): | ||||
|         socket_path = '/some/path/to.sock' | ||||
|         args = [module.UNIX, socket_path] | ||||
|         expected_kwargs = { | ||||
|             module.CLIENT_CLASS: UnixControlClient, | ||||
|             module.SOCKET_PATH: Path(socket_path) | ||||
|         } | ||||
|         parsed_kwargs = module.parse_cli(args) | ||||
|         self.assertDictEqual(expected_kwargs, parsed_kwargs) | ||||
|  | ||||
|         host, port = '1.2.3.4', '1234' | ||||
|         args = [module.TCP, host, port] | ||||
|         expected_kwargs = { | ||||
|             module.CLIENT_CLASS: TCPControlClient, | ||||
|             module.HOST: host, | ||||
|             module.PORT: int(port) | ||||
|         } | ||||
|         parsed_kwargs = module.parse_cli(args) | ||||
|         self.assertDictEqual(expected_kwargs, parsed_kwargs) | ||||
|  | ||||
|         with patch('sys.stderr'): | ||||
|             with self.assertRaises(SystemExit): | ||||
|                 module.parse_cli(['invalid', 'foo', 'bar']) | ||||
|  | ||||
|     @patch.object(module, 'parse_cli') | ||||
|     async def test_main(self, mock_parse_cli: MagicMock): | ||||
|         mock_client_start = AsyncMock() | ||||
|         mock_client = MagicMock(start=mock_client_start) | ||||
|         mock_client_cls = MagicMock(return_value=mock_client) | ||||
|         mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789} | ||||
|         mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs | ||||
|         self.assertIsNone(await module.main()) | ||||
|         mock_parse_cli.assert_called_once_with() | ||||
|         mock_client_cls.assert_called_once_with(**mock_client_kwargs) | ||||
|         mock_client_start.assert_awaited_once_with() | ||||
| @@ -20,14 +20,15 @@ Unittests for the `asyncio_taskpool.client` module. | ||||
| 
 | ||||
| 
 | ||||
| import json | ||||
| import os | ||||
| import shutil | ||||
| import sys | ||||
| from pathlib import Path | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import AsyncMock, MagicMock, patch | ||||
| from unittest import IsolatedAsyncioTestCase, skipIf | ||||
| from unittest.mock import AsyncMock, MagicMock, call, patch | ||||
| 
 | ||||
| from asyncio_taskpool import client | ||||
| from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES | ||||
| from asyncio_taskpool.control import client | ||||
| from asyncio_taskpool.internals.constants import CLIENT_INFO, SESSION_MSG_BYTES | ||||
| 
 | ||||
| 
 | ||||
| FOO, BAR = 'foo', 'bar' | ||||
| @@ -36,7 +37,7 @@ FOO, BAR = 'foo', 'bar' | ||||
| class ControlClientTestCase(IsolatedAsyncioTestCase): | ||||
| 
 | ||||
|     def setUp(self) -> None: | ||||
|         self.abstract_patcher = patch('asyncio_taskpool.client.ControlClient.__abstractmethods__', set()) | ||||
|         self.abstract_patcher = patch('asyncio_taskpool.control.client.ControlClient.__abstractmethods__', set()) | ||||
|         self.print_patcher = patch.object(client, 'print') | ||||
|         self.mock_abstract_methods = self.abstract_patcher.start() | ||||
|         self.mock_print = self.print_patcher.start() | ||||
| @@ -54,7 +55,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase): | ||||
| 
 | ||||
|     def test_client_info(self): | ||||
|         self.assertEqual({CLIENT_INFO.TERMINAL_WIDTH: shutil.get_terminal_size().columns}, | ||||
|                          client.ControlClient.client_info()) | ||||
|                          client.ControlClient._client_info()) | ||||
| 
 | ||||
|     async def test_abstract(self): | ||||
|         with self.assertRaises(NotImplementedError): | ||||
| @@ -64,16 +65,19 @@ class ControlClientTestCase(IsolatedAsyncioTestCase): | ||||
|         self.assertEqual(self.kwargs, self.client._conn_kwargs) | ||||
|         self.assertFalse(self.client._connected) | ||||
| 
 | ||||
|     @patch.object(client.ControlClient, 'client_info') | ||||
|     async def test__server_handshake(self, mock_client_info: MagicMock): | ||||
|         mock_client_info.return_value = mock_info = {FOO: 1, BAR: 9999} | ||||
|     @patch.object(client.ControlClient, '_client_info') | ||||
|     async def test__server_handshake(self, mock__client_info: MagicMock): | ||||
|         mock__client_info.return_value = mock_info = {FOO: 1, BAR: 9999} | ||||
|         self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer)) | ||||
|         self.assertTrue(self.client._connected) | ||||
|         mock_client_info.assert_called_once_with() | ||||
|         mock__client_info.assert_called_once_with() | ||||
|         self.mock_write.assert_called_once_with(json.dumps(mock_info).encode()) | ||||
|         self.mock_drain.assert_awaited_once_with() | ||||
|         self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) | ||||
|         self.mock_print.assert_called_once_with("Connected to", self.mock_read.return_value.decode()) | ||||
|         self.mock_print.assert_has_calls([ | ||||
|             call("Connected to", self.mock_read.return_value.decode()), | ||||
|             call("Type '-h' to get help and usage instructions for all available commands.\n") | ||||
|         ]) | ||||
| 
 | ||||
|     @patch.object(client, 'input') | ||||
|     def test__get_command(self, mock_input: MagicMock): | ||||
| @@ -171,6 +175,44 @@ class ControlClientTestCase(IsolatedAsyncioTestCase): | ||||
|         self.mock_print.assert_called_once_with("Disconnected from control server.") | ||||
| 
 | ||||
| 
 | ||||
| class TCPControlClientTestCase(IsolatedAsyncioTestCase): | ||||
| 
 | ||||
|     def setUp(self) -> None: | ||||
|         self.base_init_patcher = patch.object(client.ControlClient, '__init__') | ||||
|         self.mock_base_init = self.base_init_patcher.start() | ||||
|         self.host, self.port = 'localhost', 12345 | ||||
|         self.kwargs = {FOO: 123, BAR: 456} | ||||
|         self.client = client.TCPControlClient(host=self.host, port=self.port, **self.kwargs) | ||||
| 
 | ||||
|     def tearDown(self) -> None: | ||||
|         self.base_init_patcher.stop() | ||||
| 
 | ||||
|     def test_init(self): | ||||
|         self.assertEqual(self.host, self.client._host) | ||||
|         self.assertEqual(self.port, self.client._port) | ||||
|         self.mock_base_init.assert_called_once_with(**self.kwargs) | ||||
| 
 | ||||
|     @patch.object(client, 'print') | ||||
|     @patch.object(client, 'open_connection') | ||||
|     async def test__open_connection(self, mock_open_connection: AsyncMock, mock_print: MagicMock): | ||||
|         mock_open_connection.return_value = expected_output = 'something' | ||||
|         kwargs = {'a': 1, 'b': 2} | ||||
|         output = await self.client._open_connection(**kwargs) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs) | ||||
|         mock_print.assert_not_called() | ||||
| 
 | ||||
|         mock_open_connection.reset_mock() | ||||
| 
 | ||||
|         mock_open_connection.side_effect = e = ConnectionError() | ||||
|         output1, output2 = await self.client._open_connection(**kwargs) | ||||
|         self.assertIsNone(output1) | ||||
|         self.assertIsNone(output2) | ||||
|         mock_open_connection.assert_awaited_once_with(self.host, self.port, **kwargs) | ||||
|         mock_print.assert_called_once_with(str(e), file=sys.stderr) | ||||
| 
 | ||||
| 
 | ||||
| @skipIf(os.name == 'nt', "No Unix sockets on Windows :(") | ||||
| class UnixControlClientTestCase(IsolatedAsyncioTestCase): | ||||
| 
 | ||||
|     def setUp(self) -> None: | ||||
| @@ -188,9 +230,9 @@ class UnixControlClientTestCase(IsolatedAsyncioTestCase): | ||||
|         self.mock_base_init.assert_called_once_with(**self.kwargs) | ||||
| 
 | ||||
|     @patch.object(client, 'print') | ||||
|     @patch.object(client, 'open_unix_connection') | ||||
|     async def test__open_connection(self, mock_open_unix_connection: AsyncMock, mock_print: MagicMock): | ||||
|         mock_open_unix_connection.return_value = expected_output = 'something' | ||||
|     async def test__open_connection(self, mock_print: MagicMock): | ||||
|         expected_output = 'something' | ||||
|         self.client._open_unix_connection = mock_open_unix_connection = AsyncMock(return_value=expected_output) | ||||
|         kwargs = {'a': 1, 'b': 2} | ||||
|         output = await self.client._open_connection(**kwargs) | ||||
|         self.assertEqual(expected_output, output) | ||||
							
								
								
									
										313
									
								
								tests/test_control/test_parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										313
									
								
								tests/test_control/test_parser.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,313 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.control.parser` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from argparse import ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, SUPPRESS | ||||
| from ast import literal_eval | ||||
| from inspect import signature | ||||
| from unittest import TestCase | ||||
| from unittest.mock import MagicMock, call, patch | ||||
| from typing import Iterable | ||||
|  | ||||
| from asyncio_taskpool.control import parser | ||||
| from asyncio_taskpool.exceptions import HelpRequested, ParserError | ||||
| from asyncio_taskpool.internals.helpers import resolve_dotted_path | ||||
| from asyncio_taskpool.internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT | ||||
|  | ||||
|  | ||||
| FOO, BAR = 'foo', 'bar' | ||||
|  | ||||
|  | ||||
| class ControlParserTestCase(TestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory') | ||||
|         self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() | ||||
|         self.mock_help_formatter_factory.return_value = RawTextHelpFormatter | ||||
|         self.stream_writer, self.terminal_width = MagicMock(), 420 | ||||
|         self.kwargs = { | ||||
|             'stream_writer': self.stream_writer, | ||||
|             'terminal_width': self.terminal_width, | ||||
|             'formatter_class': FOO | ||||
|         } | ||||
|         self.parser = parser.ControlParser(**self.kwargs) | ||||
|  | ||||
|     def tearDown(self) -> None: | ||||
|         self.help_formatter_factory_patcher.stop() | ||||
|  | ||||
|     def test_help_formatter_factory(self): | ||||
|         self.help_formatter_factory_patcher.stop() | ||||
|  | ||||
|         class MockBaseClass(HelpFormatter): | ||||
|             def __init__(self, *args, **kwargs): | ||||
|                 super().__init__(*args, **kwargs) | ||||
|  | ||||
|         terminal_width = 123456789 | ||||
|         cls = parser.ControlParser.help_formatter_factory(terminal_width, MockBaseClass) | ||||
|         self.assertTrue(issubclass(cls, MockBaseClass)) | ||||
|         instance = cls('prog') | ||||
|         self.assertEqual(terminal_width, getattr(instance, '_width')) | ||||
|  | ||||
|         cls = parser.ControlParser.help_formatter_factory(terminal_width) | ||||
|         self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter)) | ||||
|         instance = cls('prog') | ||||
|         self.assertEqual(terminal_width, getattr(instance, '_width')) | ||||
|  | ||||
|     def test_init(self): | ||||
|         self.assertIsInstance(self.parser, ArgumentParser) | ||||
|         self.assertEqual(self.stream_writer, self.parser._stream_writer) | ||||
|         self.assertEqual(self.terminal_width, self.parser._terminal_width) | ||||
|         self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) | ||||
|         self.assertFalse(getattr(self.parser, 'exit_on_error')) | ||||
|         self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) | ||||
|         self.assertSetEqual(set(), self.parser._flags) | ||||
|         self.assertIsNone(self.parser._commands) | ||||
|  | ||||
|     @patch.object(parser, 'get_first_doc_line') | ||||
|     def test_add_function_command(self, mock_get_first_doc_line: MagicMock): | ||||
|         def foo_bar(): pass | ||||
|         mock_subparser = MagicMock() | ||||
|         mock_add_parser = MagicMock(return_value=mock_subparser) | ||||
|         self.parser._commands = MagicMock(add_parser=mock_add_parser) | ||||
|         mock_get_first_doc_line.return_value = mock_help = 'help 123' | ||||
|         kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} | ||||
|         expected_name = 'foo-bar' | ||||
|         expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs | ||||
|         to_omit = ['abc', 'xyz'] | ||||
|         output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs) | ||||
|         self.assertEqual(mock_subparser, output) | ||||
|         mock_add_parser.assert_called_once_with(**expected_kwargs) | ||||
|         mock_subparser.add_function_args.assert_called_once_with(foo_bar, to_omit) | ||||
|  | ||||
|     @patch.object(parser, 'get_first_doc_line') | ||||
|     def test_add_property_command(self, mock_get_first_doc_line: MagicMock): | ||||
|         def get_prop(_self): pass | ||||
|         def set_prop(_self, _value): pass | ||||
|         prop = property(get_prop) | ||||
|         mock_subparser = MagicMock() | ||||
|         mock_add_parser = MagicMock(return_value=mock_subparser) | ||||
|         self.parser._commands = MagicMock(add_parser=mock_add_parser) | ||||
|         mock_get_first_doc_line.return_value = mock_help = 'help 123' | ||||
|         kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} | ||||
|         expected_name = 'get-prop' | ||||
|         expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs | ||||
|         output = self.parser.add_property_command(prop, **kwargs) | ||||
|         self.assertEqual(mock_subparser, output) | ||||
|         mock_get_first_doc_line.assert_called_once_with(get_prop) | ||||
|         mock_add_parser.assert_called_once_with(**expected_kwargs) | ||||
|         mock_subparser.add_function_arg.assert_not_called() | ||||
|  | ||||
|         mock_get_first_doc_line.reset_mock() | ||||
|         mock_add_parser.reset_mock() | ||||
|  | ||||
|         prop = property(get_prop, set_prop) | ||||
|         expected_help = f"Get/set the `.{expected_name}` property" | ||||
|         expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs | ||||
|         output = self.parser.add_property_command(prop, **kwargs) | ||||
|         self.assertEqual(mock_subparser, output) | ||||
|         mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)]) | ||||
|         mock_add_parser.assert_called_once_with(**expected_kwargs) | ||||
|         mock_subparser.add_function_arg.assert_called_once_with( | ||||
|             tuple(signature(set_prop).parameters.values())[1], | ||||
|             nargs='?', | ||||
|             default=SUPPRESS, | ||||
|             help=f"If provided: {mock_help} If omitted: {mock_help}" | ||||
|         ) | ||||
|  | ||||
|     @patch.object(parser.ControlParser, 'add_property_command') | ||||
|     @patch.object(parser.ControlParser, 'add_function_command') | ||||
|     def test_add_class_commands(self, mock_add_function_command: MagicMock, mock_add_property_command: MagicMock): | ||||
|         class FooBar: | ||||
|             some_attribute = None | ||||
|  | ||||
|             def _protected(self, _): pass | ||||
|  | ||||
|             def __private(self, _): pass | ||||
|  | ||||
|             def to_omit(self, _): pass | ||||
|  | ||||
|             def method(self, _): pass | ||||
|  | ||||
|             @property | ||||
|             def prop(self): return None | ||||
|  | ||||
|         mock_set_defaults = MagicMock() | ||||
|         mock_subparser = MagicMock(set_defaults=mock_set_defaults) | ||||
|         mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser | ||||
|         x = 'x' | ||||
|         common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer, | ||||
|                          parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width} | ||||
|         expected_output = {'method': mock_subparser, 'prop': mock_subparser} | ||||
|         output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x) | ||||
|         self.assertDictEqual(expected_output, output) | ||||
|         mock_add_function_command.assert_called_once_with(FooBar.method, **common_kwargs) | ||||
|         mock_add_property_command.assert_called_once_with(FooBar.prop, FooBar.__name__, **common_kwargs) | ||||
|         mock_set_defaults.assert_has_calls([call(**{x: FooBar.method}), call(**{x: FooBar.prop})]) | ||||
|  | ||||
|     @patch.object(parser.ArgumentParser, 'add_subparsers') | ||||
|     def test_add_subparsers(self, mock_base_add_subparsers: MagicMock): | ||||
|         args, kwargs = [1, 2, 42], {FOO: 123, BAR: 456} | ||||
|         mock_base_add_subparsers.return_value = mock_action = MagicMock() | ||||
|         output = self.parser.add_subparsers(*args, **kwargs) | ||||
|         self.assertEqual(mock_action, output) | ||||
|         mock_base_add_subparsers.assert_called_once_with(*args, **kwargs) | ||||
|  | ||||
|     def test__print_message(self): | ||||
|         self.stream_writer.write = MagicMock() | ||||
|         self.assertIsNone(self.parser._print_message('')) | ||||
|         self.stream_writer.write.assert_not_called() | ||||
|         msg = 'foo bar baz' | ||||
|         self.assertIsNone(self.parser._print_message(msg)) | ||||
|         self.stream_writer.write.assert_called_once_with(msg.encode()) | ||||
|  | ||||
|     @patch.object(parser.ControlParser, '_print_message') | ||||
|     def test_exit(self, mock__print_message: MagicMock): | ||||
|         self.assertIsNone(self.parser.exit(123, '')) | ||||
|         mock__print_message.assert_not_called() | ||||
|         msg = 'foo bar baz' | ||||
|         self.assertIsNone(self.parser.exit(123, msg)) | ||||
|         mock__print_message.assert_called_once_with(msg) | ||||
|  | ||||
|     @patch.object(parser.ArgumentParser, 'error') | ||||
|     def test_error(self, mock_supercls_error: MagicMock): | ||||
|         with self.assertRaises(ParserError): | ||||
|             self.parser.error(FOO + BAR) | ||||
|         mock_supercls_error.assert_called_once_with(message=FOO + BAR) | ||||
|  | ||||
|     @patch.object(parser.ArgumentParser, 'print_help') | ||||
|     def test_print_help(self, mock_print_help: MagicMock): | ||||
|         arg = MagicMock() | ||||
|         with self.assertRaises(HelpRequested): | ||||
|             self.parser.print_help(arg) | ||||
|         mock_print_help.assert_called_once_with(arg) | ||||
|  | ||||
|     @patch.object(parser, '_get_type_from_annotation') | ||||
|     @patch.object(parser.ArgumentParser, 'add_argument') | ||||
|     def test_add_function_arg(self, mock_add_argument: MagicMock, mock__get_type_from_annotation: MagicMock): | ||||
|         mock_add_argument.return_value = expected_output = 'action' | ||||
|         mock__get_type_from_annotation.return_value = mock_type = 'fake' | ||||
|  | ||||
|         foo_type, args_type, bar_type, baz_type, boo_type = tuple, str, int, float, complex | ||||
|         bar_default, baz_default, boo_default = 1, 0.1, 1j | ||||
|  | ||||
|         def func(foo: foo_type, *args: args_type, bar: bar_type = bar_default, baz: baz_type = baz_default, | ||||
|                  boo: boo_type = boo_default, flag: bool = False): | ||||
|             return foo, args, bar, baz, boo, flag | ||||
|  | ||||
|         param_foo, param_args, param_bar, param_baz, param_boo, param_flag = signature(func).parameters.values() | ||||
|         kwargs = {FOO + BAR: 'xyz'} | ||||
|         self.assertEqual(expected_output, self.parser.add_function_arg(param_foo, **kwargs)) | ||||
|         mock_add_argument.assert_called_once_with('foo', type=mock_type, **kwargs) | ||||
|         mock__get_type_from_annotation.assert_called_once_with(foo_type) | ||||
|  | ||||
|         mock_add_argument.reset_mock() | ||||
|         mock__get_type_from_annotation.reset_mock() | ||||
|  | ||||
|         self.assertEqual(expected_output, self.parser.add_function_arg(param_args, **kwargs)) | ||||
|         mock_add_argument.assert_called_once_with('args', nargs='*', type=mock_type, **kwargs) | ||||
|         mock__get_type_from_annotation.assert_called_once_with(args_type) | ||||
|  | ||||
|         mock_add_argument.reset_mock() | ||||
|         mock__get_type_from_annotation.reset_mock() | ||||
|  | ||||
|         self.assertEqual(expected_output, self.parser.add_function_arg(param_bar, **kwargs)) | ||||
|         mock_add_argument.assert_called_once_with('-b', '--bar', default=bar_default, type=mock_type, **kwargs) | ||||
|         mock__get_type_from_annotation.assert_called_once_with(bar_type) | ||||
|  | ||||
|         mock_add_argument.reset_mock() | ||||
|         mock__get_type_from_annotation.reset_mock() | ||||
|  | ||||
|         self.assertEqual(expected_output, self.parser.add_function_arg(param_baz, **kwargs)) | ||||
|         mock_add_argument.assert_called_once_with('-B', '--baz', default=baz_default, type=mock_type, **kwargs) | ||||
|         mock__get_type_from_annotation.assert_called_once_with(baz_type) | ||||
|  | ||||
|         mock_add_argument.reset_mock() | ||||
|         mock__get_type_from_annotation.reset_mock() | ||||
|  | ||||
|         self.assertEqual(expected_output, self.parser.add_function_arg(param_boo, **kwargs)) | ||||
|         mock_add_argument.assert_called_once_with('--boo', default=boo_default, type=mock_type, **kwargs) | ||||
|         mock__get_type_from_annotation.assert_called_once_with(boo_type) | ||||
|  | ||||
|         mock_add_argument.reset_mock() | ||||
|         mock__get_type_from_annotation.reset_mock() | ||||
|  | ||||
|         self.assertEqual(expected_output, self.parser.add_function_arg(param_flag, **kwargs)) | ||||
|         mock_add_argument.assert_called_once_with('-f', '--flag', action='store_true', **kwargs) | ||||
|         mock__get_type_from_annotation.assert_not_called() | ||||
|  | ||||
|     @patch.object(parser.ControlParser, 'add_function_arg') | ||||
|     def test_add_function_args(self, mock_add_function_arg: MagicMock): | ||||
|         def func(foo: str, *args: int, bar: float = 0.1): | ||||
|             return foo, args, bar | ||||
|         _, param_args, param_bar = signature(func).parameters.values() | ||||
|         self.assertIsNone(self.parser.add_function_args(func, omit=['foo'])) | ||||
|         mock_add_function_arg.assert_has_calls([ | ||||
|             call(param_args, help=repr(param_args.annotation)), | ||||
|             call(param_bar, help=repr(param_bar.annotation)), | ||||
|         ]) | ||||
|  | ||||
|  | ||||
| class RestTestCase(TestCase): | ||||
|     log_lvl: int | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls) -> None: | ||||
|         cls.log_lvl = parser.log.level | ||||
|         parser.log.setLevel(999) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls) -> None: | ||||
|         parser.log.setLevel(cls.log_lvl) | ||||
|  | ||||
|     def test__get_arg_type_wrapper(self): | ||||
|         type_wrap = parser._get_arg_type_wrapper(int) | ||||
|         self.assertEqual('int', type_wrap.__name__) | ||||
|         self.assertEqual(SUPPRESS, type_wrap(SUPPRESS)) | ||||
|         self.assertEqual(13, type_wrap('13')) | ||||
|  | ||||
|         name = 'abcdef' | ||||
|         mock_type = MagicMock(side_effect=[parser.ArgumentTypeError, TypeError, ValueError, Exception], __name__=name) | ||||
|         type_wrap = parser._get_arg_type_wrapper(mock_type) | ||||
|         self.assertEqual(name, type_wrap.__name__) | ||||
|         with self.assertRaises(parser.ArgumentTypeError): | ||||
|             type_wrap(FOO) | ||||
|         with self.assertRaises(TypeError): | ||||
|             type_wrap(FOO) | ||||
|         with self.assertRaises(ValueError): | ||||
|             type_wrap(FOO) | ||||
|         with self.assertRaises(parser.ArgumentTypeError): | ||||
|             type_wrap(FOO) | ||||
|  | ||||
|     @patch.object(parser, '_get_arg_type_wrapper') | ||||
|     def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock): | ||||
|         mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR | ||||
|         dotted_path_ann = [CoroutineFunc, EndCB, CancelCB] | ||||
|         literal_eval_ann = [ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]] | ||||
|         any_other_ann = MagicMock() | ||||
|         for a in dotted_path_ann: | ||||
|             self.assertEqual(expected_output, parser._get_type_from_annotation(a)) | ||||
|         mock__get_arg_type_wrapper.assert_has_calls(len(dotted_path_ann) * [call(resolve_dotted_path)]) | ||||
|         mock__get_arg_type_wrapper.reset_mock() | ||||
|         for a in literal_eval_ann: | ||||
|             self.assertEqual(expected_output, parser._get_type_from_annotation(a)) | ||||
|         mock__get_arg_type_wrapper.assert_has_calls(len(literal_eval_ann) * [call(literal_eval)]) | ||||
|         mock__get_arg_type_wrapper.reset_mock() | ||||
|         self.assertEqual(expected_output, parser._get_type_from_annotation(any_other_ann)) | ||||
|         mock__get_arg_type_wrapper.assert_called_once_with(any_other_ann) | ||||
| @@ -21,12 +21,13 @@ Unittests for the `asyncio_taskpool.server` module. | ||||
| 
 | ||||
| import asyncio | ||||
| import logging | ||||
| import os | ||||
| from pathlib import Path | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest import IsolatedAsyncioTestCase, skipIf | ||||
| from unittest.mock import AsyncMock, MagicMock, patch | ||||
| 
 | ||||
| from asyncio_taskpool import server | ||||
| from asyncio_taskpool.client import ControlClient, UnixControlClient | ||||
| from asyncio_taskpool.control import server | ||||
| from asyncio_taskpool.control.client import ControlClient, TCPControlClient, UnixControlClient | ||||
| 
 | ||||
| 
 | ||||
| FOO, BAR = 'foo', 'bar' | ||||
| @@ -45,7 +46,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|         server.log.setLevel(cls.log_lvl) | ||||
| 
 | ||||
|     def setUp(self) -> None: | ||||
|         self.abstract_patcher = patch('asyncio_taskpool.server.ControlServer.__abstractmethods__', set()) | ||||
|         self.abstract_patcher = patch('asyncio_taskpool.control.server.ControlServer.__abstractmethods__', set()) | ||||
|         self.mock_abstract_methods = self.abstract_patcher.start() | ||||
|         self.mock_pool = MagicMock() | ||||
|         self.kwargs = {FOO: 123, BAR: 456} | ||||
| @@ -119,6 +120,51 @@ class ControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|         mock_create_task.assert_called_once_with(mock_awaitable) | ||||
| 
 | ||||
| 
 | ||||
| class TCPControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|     log_lvl: int | ||||
| 
 | ||||
|     @classmethod | ||||
|     def setUpClass(cls) -> None: | ||||
|         cls.log_lvl = server.log.level | ||||
|         server.log.setLevel(999) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def tearDownClass(cls) -> None: | ||||
|         server.log.setLevel(cls.log_lvl) | ||||
| 
 | ||||
|     def setUp(self) -> None: | ||||
|         self.base_init_patcher = patch.object(server.ControlServer, '__init__') | ||||
|         self.mock_base_init = self.base_init_patcher.start() | ||||
|         self.mock_pool = MagicMock() | ||||
|         self.host, self.port = 'localhost', 12345 | ||||
|         self.kwargs = {FOO: 123, BAR: 456} | ||||
|         self.server = server.TCPControlServer(pool=self.mock_pool, host=self.host, port=self.port, **self.kwargs) | ||||
| 
 | ||||
|     def tearDown(self) -> None: | ||||
|         self.base_init_patcher.stop() | ||||
| 
 | ||||
|     def test__client_class(self): | ||||
|         self.assertEqual(TCPControlClient, self.server._client_class) | ||||
| 
 | ||||
|     def test_init(self): | ||||
|         self.assertEqual(self.host, self.server._host) | ||||
|         self.assertEqual(self.port, self.server._port) | ||||
|         self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs) | ||||
| 
 | ||||
|     @patch.object(server, 'start_server') | ||||
|     async def test__get_server_instance(self, mock_start_server: AsyncMock): | ||||
|         mock_start_server.return_value = expected_output = 'totally_a_server' | ||||
|         mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2} | ||||
|         args = [mock_callback] | ||||
|         output = await self.server._get_server_instance(*args, **mock_kwargs) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_start_server.assert_called_once_with(mock_callback, self.host, self.port, **mock_kwargs) | ||||
| 
 | ||||
|     def test__final_callback(self): | ||||
|         self.assertIsNone(self.server._final_callback()) | ||||
| 
 | ||||
| 
 | ||||
| @skipIf(os.name == 'nt', "No Unix sockets on Windows :(") | ||||
| class UnixControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|     log_lvl: int | ||||
| 
 | ||||
| @@ -137,7 +183,7 @@ class UnixControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|         self.mock_pool = MagicMock() | ||||
|         self.path = '/tmp/asyncio_taskpool' | ||||
|         self.kwargs = {FOO: 123, BAR: 456} | ||||
|         self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs) | ||||
|         self.server = server.UnixControlServer(pool=self.mock_pool, socket_path=self.path, **self.kwargs) | ||||
| 
 | ||||
|     def tearDown(self) -> None: | ||||
|         self.base_init_patcher.stop() | ||||
| @@ -149,9 +195,9 @@ class UnixControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|         self.assertEqual(Path(self.path), self.server._socket_path) | ||||
|         self.mock_base_init.assert_called_once_with(self.mock_pool, **self.kwargs) | ||||
| 
 | ||||
|     @patch.object(server, 'start_unix_server') | ||||
|     async def test__get_server_instance(self, mock_start_unix_server: AsyncMock): | ||||
|         mock_start_unix_server.return_value = expected_output = 'totally_a_server' | ||||
|     async def test__get_server_instance(self): | ||||
|         expected_output = 'totally_a_server' | ||||
|         self.server._start_unix_server = mock_start_unix_server = AsyncMock(return_value=expected_output) | ||||
|         mock_callback, mock_kwargs = MagicMock(), {'a': 1, 'b': 2} | ||||
|         args = [mock_callback] | ||||
|         output = await self.server._get_server_instance(*args, **mock_kwargs) | ||||
							
								
								
									
										207
									
								
								tests/test_control/test_session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										207
									
								
								tests/test_control/test_session.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,207 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.session` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import json | ||||
| from argparse import ArgumentError, Namespace | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import AsyncMock, MagicMock, patch, call | ||||
|  | ||||
| from asyncio_taskpool.control import session | ||||
| from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER | ||||
| from asyncio_taskpool.exceptions import HelpRequested | ||||
| from asyncio_taskpool.pool import SimpleTaskPool | ||||
|  | ||||
|  | ||||
| FOO, BAR = 'foo', 'bar' | ||||
|  | ||||
|  | ||||
| class ControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|     log_lvl: int | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls) -> None: | ||||
|         cls.log_lvl = session.log.level | ||||
|         session.log.setLevel(999) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls) -> None: | ||||
|         session.log.setLevel(cls.log_lvl) | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock())) | ||||
|         self.mock_client_class_name = FOO + BAR | ||||
|         self.mock_server = MagicMock(pool=self.mock_pool, | ||||
|                                      client_class_name=self.mock_client_class_name) | ||||
|         self.mock_reader = MagicMock() | ||||
|         self.mock_writer = MagicMock() | ||||
|         self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer) | ||||
|  | ||||
|     def test_init(self): | ||||
|         self.assertEqual(self.mock_server, self.session._control_server) | ||||
|         self.assertEqual(self.mock_pool, self.session._pool) | ||||
|         self.assertEqual(self.mock_client_class_name, self.session._client_class_name) | ||||
|         self.assertEqual(self.mock_reader, self.session._reader) | ||||
|         self.assertEqual(self.mock_writer, self.session._writer) | ||||
|         self.assertIsNone(self.session._parser) | ||||
|  | ||||
|     @patch.object(session, 'return_or_exception') | ||||
|     async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock): | ||||
|         def method(self, arg1, arg2, *var_args, **rest): pass | ||||
|         test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11} | ||||
|         kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest | ||||
|         mock_return_or_exception.return_value = None | ||||
|         self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs)) | ||||
|         mock_return_or_exception.assert_awaited_once_with( | ||||
|             method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest | ||||
|         ) | ||||
|         self.mock_writer.write.assert_called_once_with(session.CMD_OK) | ||||
|  | ||||
|     @patch.object(session, 'return_or_exception') | ||||
|     async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock): | ||||
|         def prop_get(_): pass | ||||
|         def prop_set(_): pass | ||||
|         prop = property(prop_get, prop_set) | ||||
|         kwargs = {'value': 'something'} | ||||
|         mock_return_or_exception.return_value = None | ||||
|         self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs)) | ||||
|         mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs) | ||||
|         self.mock_writer.write.assert_called_once_with(session.CMD_OK) | ||||
|  | ||||
|         mock_return_or_exception.reset_mock() | ||||
|         self.mock_writer.write.reset_mock() | ||||
|  | ||||
|         mock_return_or_exception.return_value = val = 420.69 | ||||
|         self.assertIsNone(await self.session._exec_property_and_respond(prop)) | ||||
|         mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool) | ||||
|         self.mock_writer.write.assert_called_once_with(str(val).encode()) | ||||
|  | ||||
|     @patch.object(session, 'ControlParser') | ||||
|     async def test_client_handshake(self, mock_parser_cls: MagicMock): | ||||
|         mock_add_subparsers, mock_add_class_commands = MagicMock(), MagicMock() | ||||
|         mock_parser = MagicMock(add_subparsers=mock_add_subparsers, add_class_commands=mock_add_class_commands) | ||||
|         mock_parser_cls.return_value = mock_parser | ||||
|         width = 5678 | ||||
|         msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + '  ' | ||||
|         mock_read = AsyncMock(return_value=msg.encode()) | ||||
|         self.mock_reader.read = mock_read | ||||
|         self.mock_writer.drain = AsyncMock() | ||||
|         expected_parser_kwargs = { | ||||
|             STREAM_WRITER: self.mock_writer, | ||||
|             CLIENT_INFO.TERMINAL_WIDTH: width, | ||||
|             'prog': '', | ||||
|             'usage': f'[-h] [{CMD}] ...' | ||||
|         } | ||||
|         expected_subparsers_kwargs = { | ||||
|             'title': "Commands", | ||||
|             'metavar': "(A command followed by '-h' or '--help' will show command-specific help.)" | ||||
|         } | ||||
|         self.assertIsNone(await self.session.client_handshake()) | ||||
|         self.assertEqual(mock_parser, self.session._parser) | ||||
|         mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) | ||||
|         mock_parser_cls.assert_called_once_with(**expected_parser_kwargs) | ||||
|         mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs) | ||||
|         mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__) | ||||
|         self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode()) | ||||
|         self.mock_writer.drain.assert_awaited_once_with() | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_exec_property_and_respond') | ||||
|     @patch.object(session.ControlSession, '_exec_method_and_respond') | ||||
|     async def test__parse_command(self, mock__exec_method_and_respond: AsyncMock, | ||||
|                                   mock__exec_property_and_respond: AsyncMock): | ||||
|         def method(_): pass | ||||
|         prop = property(method) | ||||
|         msg = 'asdf asd as a' | ||||
|         kwargs = {FOO: BAR, 'hello': 'python'} | ||||
|         mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs)) | ||||
|         self.session._parser = MagicMock(parse_args=mock_parse_args) | ||||
|         self.mock_writer.write = MagicMock() | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_not_called() | ||||
|         mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs) | ||||
|         mock__exec_property_and_respond.assert_not_called() | ||||
|  | ||||
|         mock__exec_method_and_respond.reset_mock() | ||||
|         mock_parse_args.reset_mock() | ||||
|  | ||||
|         mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs) | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_not_called() | ||||
|         mock__exec_method_and_respond.assert_not_called() | ||||
|         mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs) | ||||
|  | ||||
|         mock__exec_property_and_respond.reset_mock() | ||||
|         mock_parse_args.reset_mock() | ||||
|  | ||||
|         bad_command = 'definitely not a function or property' | ||||
|         mock_parse_args.return_value = Namespace(**{CMD: bad_command}, **kwargs) | ||||
|         with patch.object(session, 'CommandError') as cmd_err_cls: | ||||
|             cmd_err_cls.return_value = exc = MagicMock() | ||||
|             self.assertIsNone(await self.session._parse_command(msg)) | ||||
|             cmd_err_cls.assert_called_once_with(f"Unknown command object: {bad_command}") | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         mock__exec_method_and_respond.assert_not_called() | ||||
|         mock__exec_property_and_respond.assert_not_called() | ||||
|         self.mock_writer.write.assert_called_once_with(str(exc).encode()) | ||||
|  | ||||
|         mock__exec_property_and_respond.reset_mock() | ||||
|         mock_parse_args.reset_mock() | ||||
|         self.mock_writer.write.reset_mock() | ||||
|  | ||||
|         mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_called_once_with(str(exc).encode()) | ||||
|         mock__exec_method_and_respond.assert_not_awaited() | ||||
|         mock__exec_property_and_respond.assert_not_awaited() | ||||
|  | ||||
|         self.mock_writer.write.reset_mock() | ||||
|         mock_parse_args.reset_mock() | ||||
|  | ||||
|         mock_parse_args.side_effect = HelpRequested() | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_not_called() | ||||
|         mock__exec_method_and_respond.assert_not_awaited() | ||||
|         mock__exec_property_and_respond.assert_not_awaited() | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_parse_command') | ||||
|     async def test_listen(self, mock__parse_command: AsyncMock): | ||||
|         def make_reader_return_empty(): | ||||
|             self.mock_reader.read.return_value = b'' | ||||
|         self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) | ||||
|         msg = "fascinating" | ||||
|         self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) | ||||
|         self.assertIsNone(await self.session.listen()) | ||||
|         self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) | ||||
|         mock__parse_command.assert_awaited_once_with(msg) | ||||
|         self.mock_writer.drain.assert_awaited_once_with() | ||||
|  | ||||
|         self.mock_reader.read.reset_mock() | ||||
|         mock__parse_command.reset_mock() | ||||
|         self.mock_writer.drain.reset_mock() | ||||
|  | ||||
|         self.mock_server.is_serving = MagicMock(return_value=False) | ||||
|         self.assertIsNone(await self.session.listen()) | ||||
|         self.mock_reader.read.assert_not_awaited() | ||||
|         mock__parse_command.assert_not_awaited() | ||||
|         self.mock_writer.drain.assert_not_awaited() | ||||
							
								
								
									
										0
									
								
								tests/test_internals/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/test_internals/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										84
									
								
								tests/test_internals/test_group_register.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								tests/test_internals/test_group_register.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.group_register` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from asyncio.locks import Lock | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import MagicMock, patch | ||||
|  | ||||
| from asyncio_taskpool.internals import group_register | ||||
|  | ||||
| FOO, BAR = 'foo', 'bar' | ||||
|  | ||||
|  | ||||
| class TaskGroupRegisterTestCase(IsolatedAsyncioTestCase): | ||||
|     def setUp(self) -> None: | ||||
|         self.reg = group_register.TaskGroupRegister() | ||||
|  | ||||
|     def test_init(self): | ||||
|         ids = [FOO, BAR, 1, 2] | ||||
|         reg = group_register.TaskGroupRegister(*ids) | ||||
|         self.assertSetEqual(set(ids), reg._ids) | ||||
|         self.assertIsInstance(reg._lock, Lock) | ||||
|  | ||||
|     def test___contains__(self): | ||||
|         self.reg._ids = {1, 2, 3} | ||||
|         for i in self.reg._ids: | ||||
|             self.assertTrue(i in self.reg) | ||||
|         self.assertFalse(4 in self.reg) | ||||
|  | ||||
|     @patch.object(group_register, 'iter', return_value=FOO) | ||||
|     def test___iter__(self, mock_iter: MagicMock): | ||||
|         self.assertEqual(FOO, self.reg.__iter__()) | ||||
|         mock_iter.assert_called_once_with(self.reg._ids) | ||||
|  | ||||
|     def test___len__(self): | ||||
|         self.reg._ids = [1, 2, 3, 4] | ||||
|         self.assertEqual(4, len(self.reg)) | ||||
|  | ||||
|     def test_add(self): | ||||
|         self.assertSetEqual(set(), self.reg._ids) | ||||
|         self.assertIsNone(self.reg.add(123)) | ||||
|         self.assertSetEqual({123}, self.reg._ids) | ||||
|  | ||||
|     def test_discard(self): | ||||
|         self.reg._ids = {123} | ||||
|         self.assertIsNone(self.reg.discard(0)) | ||||
|         self.assertIsNone(self.reg.discard(999)) | ||||
|         self.assertIsNone(self.reg.discard(123)) | ||||
|         self.assertSetEqual(set(), self.reg._ids) | ||||
|  | ||||
|     async def test_acquire(self): | ||||
|         self.assertFalse(self.reg._lock.locked()) | ||||
|         await self.reg.acquire() | ||||
|         self.assertTrue(self.reg._lock.locked()) | ||||
|  | ||||
|     def test_release(self): | ||||
|         self.reg._lock._locked = True | ||||
|         self.assertTrue(self.reg._lock.locked()) | ||||
|         self.reg.release() | ||||
|         self.assertFalse(self.reg._lock.locked()) | ||||
|  | ||||
|     async def test_contextmanager(self): | ||||
|         self.assertFalse(self.reg._lock.locked()) | ||||
|         async with self.reg as nothing: | ||||
|             self.assertIsNone(nothing) | ||||
|             self.assertTrue(self.reg._lock.locked()) | ||||
|         self.assertFalse(self.reg._lock.locked()) | ||||
| @@ -20,9 +20,9 @@ Unittests for the `asyncio_taskpool.helpers` module. | ||||
| 
 | ||||
| 
 | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock | ||||
| from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch | ||||
| 
 | ||||
| from asyncio_taskpool import helpers | ||||
| from asyncio_taskpool.internals import helpers | ||||
| 
 | ||||
| 
 | ||||
| class HelpersTestCase(IsolatedAsyncioTestCase): | ||||
| @@ -81,20 +81,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase): | ||||
|         with self.assertRaises(ValueError): | ||||
|             helpers.star_function(f, a, 123456789) | ||||
| 
 | ||||
|     async def test_join_queue(self): | ||||
|         mock_join = AsyncMock() | ||||
|         mock_queue = MagicMock(join=mock_join) | ||||
|         self.assertIsNone(await helpers.join_queue(mock_queue)) | ||||
|         mock_join.assert_awaited_once_with() | ||||
| 
 | ||||
|     def test_task_str(self): | ||||
|         self.assertEqual("task", helpers.tasks_str(1)) | ||||
|         self.assertEqual("tasks", helpers.tasks_str(0)) | ||||
|         self.assertEqual("tasks", helpers.tasks_str(-1)) | ||||
|         self.assertEqual("tasks", helpers.tasks_str(2)) | ||||
|         self.assertEqual("tasks", helpers.tasks_str(-10)) | ||||
|         self.assertEqual("tasks", helpers.tasks_str(42)) | ||||
| 
 | ||||
|     def test_get_first_doc_line(self): | ||||
|         expected_output = 'foo bar baz' | ||||
|         mock_obj = MagicMock(__doc__=f"""{expected_output}  | ||||
| @@ -126,3 +112,13 @@ class HelpersTestCase(IsolatedAsyncioTestCase): | ||||
|         output = await helpers.return_or_exception(mock_func, *args, **kwargs) | ||||
|         self.assertEqual(test_exception, output) | ||||
|         mock_func.assert_called_once_with(*args, **kwargs) | ||||
| 
 | ||||
|     def test_resolve_dotted_path(self): | ||||
|         from logging import WARNING | ||||
|         from urllib.request import urlopen | ||||
|         self.assertEqual(WARNING, helpers.resolve_dotted_path('logging.WARNING')) | ||||
|         self.assertEqual(urlopen, helpers.resolve_dotted_path('urllib.request.urlopen')) | ||||
|         with patch.object(helpers, 'import_module', return_value=object) as mock_import_module: | ||||
|             with self.assertRaises(AttributeError): | ||||
|                 helpers.resolve_dotted_path('foo.bar.baz') | ||||
|             mock_import_module.assert_has_calls([call('foo'), call('foo.bar')]) | ||||
| @@ -18,10 +18,8 @@ __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.pool` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import asyncio | ||||
| from asyncio.exceptions import CancelledError | ||||
| from asyncio.queues import Queue | ||||
| from asyncio.locks import Semaphore | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call | ||||
| from typing import Type | ||||
| @@ -29,8 +27,8 @@ from typing import Type | ||||
| from asyncio_taskpool import pool, exceptions | ||||
|  | ||||
|  | ||||
| EMPTY_LIST, EMPTY_DICT = [], {} | ||||
| FOO, BAR = 'foo', 'bar' | ||||
| EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set() | ||||
| FOO, BAR, BAZ = 'foo', 'bar', 'baz' | ||||
|  | ||||
|  | ||||
| class TestException(Exception): | ||||
| @@ -45,19 +43,12 @@ class CommonTestCase(IsolatedAsyncioTestCase): | ||||
|     task_pool: pool.BaseTaskPool | ||||
|     log_lvl: int | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls) -> None: | ||||
|         cls.log_lvl = pool.log.level | ||||
|         pool.log.setLevel(999) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls) -> None: | ||||
|         pool.log.setLevel(cls.log_lvl) | ||||
|  | ||||
|     def get_task_pool_init_params(self) -> dict: | ||||
|         return {'pool_size': self.TEST_POOL_SIZE, 'name': self.TEST_POOL_NAME} | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.log_lvl = pool.log.level | ||||
|         pool.log.setLevel(999) | ||||
|         self._pools = self.TEST_CLASS._pools | ||||
|         # These three methods are called during initialization, so we mock them by default during setup: | ||||
|         self._add_pool_patcher = patch.object(self.TEST_CLASS, '_add_pool') | ||||
| @@ -76,6 +67,7 @@ class CommonTestCase(IsolatedAsyncioTestCase): | ||||
|         self._add_pool_patcher.stop() | ||||
|         self.pool_size_patcher.stop() | ||||
|         self.dunder_str_patcher.stop() | ||||
|         pool.log.setLevel(self.log_lvl) | ||||
|  | ||||
|  | ||||
| class BaseTaskPoolTestCase(CommonTestCase): | ||||
| @@ -88,19 +80,24 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|         self.assertListEqual([self.task_pool], pool.BaseTaskPool._pools) | ||||
|  | ||||
|     def test_init(self): | ||||
|         self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore) | ||||
|         self.assertEqual(0, self.task_pool._num_started) | ||||
|  | ||||
|         self.assertFalse(self.task_pool._locked) | ||||
|         self.assertEqual(0, self.task_pool._counter) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._running) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._cancelled) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._ended) | ||||
|         self.assertEqual(0, self.task_pool._num_cancelled) | ||||
|         self.assertEqual(0, self.task_pool._num_ended) | ||||
|         self.assertEqual(self.mock_idx, self.task_pool._idx) | ||||
|         self.assertFalse(self.task_pool._closed) | ||||
|         self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) | ||||
|         self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) | ||||
|         self.assertIsInstance(self.task_pool._interrupt_flag, asyncio.locks.Event) | ||||
|         self.assertFalse(self.task_pool._interrupt_flag.is_set()) | ||||
|  | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) | ||||
|  | ||||
|         self.assertIsInstance(self.task_pool._enough_room, Semaphore) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) | ||||
|  | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) | ||||
|         self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) | ||||
|  | ||||
|         self.assertEqual(self.mock_idx, self.task_pool._idx) | ||||
|  | ||||
|         self.mock__add_pool.assert_called_once_with(self.task_pool) | ||||
|         self.mock_pool_size.assert_called_once_with(self.TEST_POOL_SIZE) | ||||
|         self.mock___str__.assert_called_once_with() | ||||
| @@ -115,14 +112,14 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|  | ||||
|     def test_pool_size(self): | ||||
|         self.pool_size_patcher.stop() | ||||
|         self.task_pool._pool_size = self.TEST_POOL_SIZE | ||||
|         self.task_pool._enough_room._value = self.TEST_POOL_SIZE | ||||
|         self.assertEqual(self.TEST_POOL_SIZE, self.task_pool.pool_size) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             self.task_pool.pool_size = -1 | ||||
|  | ||||
|         self.task_pool.pool_size = new_size = 69 | ||||
|         self.assertEqual(new_size, self.task_pool._pool_size) | ||||
|         self.assertEqual(new_size, self.task_pool._enough_room._value) | ||||
|  | ||||
|     def test_is_locked(self): | ||||
|         self.task_pool._locked = FOO | ||||
| @@ -143,26 +140,49 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|         self.assertFalse(self.task_pool._locked) | ||||
|  | ||||
|     def test_num_running(self): | ||||
|         self.task_pool._running = ['foo', 'bar', 'baz'] | ||||
|         self.task_pool._tasks_running = {1: FOO, 2: BAR, 3: BAZ} | ||||
|         self.assertEqual(3, self.task_pool.num_running) | ||||
|  | ||||
|     def test_num_cancelled(self): | ||||
|         self.task_pool._num_cancelled = 3 | ||||
|         self.task_pool._tasks_cancelled = {1: FOO, 2: BAR, 3: BAZ} | ||||
|         self.assertEqual(3, self.task_pool.num_cancelled) | ||||
|  | ||||
|     def test_num_ended(self): | ||||
|         self.task_pool._num_ended = 3 | ||||
|         self.task_pool._tasks_ended = {1: FOO, 2: BAR, 3: BAZ} | ||||
|         self.assertEqual(3, self.task_pool.num_ended) | ||||
|  | ||||
|     def test_num_finished(self): | ||||
|         self.task_pool._num_cancelled = cancelled = 69 | ||||
|         self.task_pool._num_ended = ended = 420 | ||||
|         self.task_pool._cancelled = mock_cancelled_dict = {1: 'foo', 2: 'bar'} | ||||
|         self.assertEqual(ended - cancelled + len(mock_cancelled_dict), self.task_pool.num_finished) | ||||
|  | ||||
|     def test_is_full(self): | ||||
|         self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full) | ||||
|  | ||||
|     def test_get_group_ids(self): | ||||
|         group_name, ids = 'abcdef', [1, 2, 3] | ||||
|         self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids)) | ||||
|         self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name)) | ||||
|         with self.assertRaises(exceptions.InvalidGroupName): | ||||
|             self.task_pool.get_group_ids(group_name, 'something else') | ||||
|  | ||||
|     async def test__check_start(self): | ||||
|         self.task_pool._closed = True | ||||
|         mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() | ||||
|         try: | ||||
|             with self.assertRaises(AssertionError): | ||||
|                 self.task_pool._check_start(awaitable=None, function=None) | ||||
|             with self.assertRaises(AssertionError): | ||||
|                 self.task_pool._check_start(awaitable=mock_coroutine, function=mock_coroutine_function) | ||||
|             with self.assertRaises(exceptions.NotCoroutine): | ||||
|                 self.task_pool._check_start(awaitable=mock_coroutine_function, function=None) | ||||
|             with self.assertRaises(exceptions.NotCoroutine): | ||||
|                 self.task_pool._check_start(awaitable=None, function=mock_coroutine) | ||||
|             with self.assertRaises(exceptions.PoolIsClosed): | ||||
|                 self.task_pool._check_start(awaitable=mock_coroutine, function=None) | ||||
|             self.task_pool._closed = False | ||||
|             self.task_pool._locked = True | ||||
|             with self.assertRaises(exceptions.PoolIsLocked): | ||||
|                 self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) | ||||
|             self.assertIsNone(self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=True)) | ||||
|         finally: | ||||
|             await mock_coroutine | ||||
|  | ||||
|     def test__task_name(self): | ||||
|         i = 123 | ||||
|         self.assertEqual(f'{self.mock_str}_Task-{i}', self.task_pool._task_name(i)) | ||||
| @@ -171,12 +191,10 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|     @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) | ||||
|     async def test__task_cancellation(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): | ||||
|         task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() | ||||
|         self.task_pool._num_cancelled = cancelled = 3 | ||||
|         self.task_pool._running[task_id] = mock_task | ||||
|         self.task_pool._tasks_running[task_id] = mock_task | ||||
|         self.assertIsNone(await self.task_pool._task_cancellation(task_id, mock_callback)) | ||||
|         self.assertNotIn(task_id, self.task_pool._running) | ||||
|         self.assertEqual(mock_task, self.task_pool._cancelled[task_id]) | ||||
|         self.assertEqual(cancelled + 1, self.task_pool._num_cancelled) | ||||
|         self.assertNotIn(task_id, self.task_pool._tasks_running) | ||||
|         self.assertEqual(mock_task, self.task_pool._tasks_cancelled[task_id]) | ||||
|         mock__task_name.assert_called_with(task_id) | ||||
|         mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) | ||||
|  | ||||
| @@ -184,15 +202,13 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|     @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) | ||||
|     async def test__task_ending(self, mock__task_name: MagicMock, mock_execute_optional: AsyncMock): | ||||
|         task_id, mock_task, mock_callback = 1, MagicMock(), MagicMock() | ||||
|         self.task_pool._num_ended = ended = 3 | ||||
|         self.task_pool._enough_room._value = room = 123 | ||||
|  | ||||
|         # End running task: | ||||
|         self.task_pool._running[task_id] = mock_task | ||||
|         self.task_pool._tasks_running[task_id] = mock_task | ||||
|         self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) | ||||
|         self.assertNotIn(task_id, self.task_pool._running) | ||||
|         self.assertEqual(mock_task, self.task_pool._ended[task_id]) | ||||
|         self.assertEqual(ended + 1, self.task_pool._num_ended) | ||||
|         self.assertNotIn(task_id, self.task_pool._tasks_running) | ||||
|         self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id]) | ||||
|         self.assertEqual(room + 1, self.task_pool._enough_room._value) | ||||
|         mock__task_name.assert_called_with(task_id) | ||||
|         mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) | ||||
| @@ -200,11 +216,10 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|         mock_execute_optional.reset_mock() | ||||
|  | ||||
|         # End cancelled task: | ||||
|         self.task_pool._cancelled[task_id] = self.task_pool._ended.pop(task_id) | ||||
|         self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_ended.pop(task_id) | ||||
|         self.assertIsNone(await self.task_pool._task_ending(task_id, mock_callback)) | ||||
|         self.assertNotIn(task_id, self.task_pool._cancelled) | ||||
|         self.assertEqual(mock_task, self.task_pool._ended[task_id]) | ||||
|         self.assertEqual(ended + 2, self.task_pool._num_ended) | ||||
|         self.assertNotIn(task_id, self.task_pool._tasks_cancelled) | ||||
|         self.assertEqual(mock_task, self.task_pool._tasks_ended[task_id]) | ||||
|         self.assertEqual(room + 2, self.task_pool._enough_room._value) | ||||
|         mock__task_name.assert_called_with(task_id) | ||||
|         mock_execute_optional.assert_awaited_once_with(mock_callback, args=(task_id, )) | ||||
| @@ -246,92 +261,52 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|     @patch.object(pool, 'create_task') | ||||
|     @patch.object(pool.BaseTaskPool, '_task_wrapper', new_callable=MagicMock) | ||||
|     @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) | ||||
|     async def test__start_task(self, mock__task_name: MagicMock, mock__task_wrapper: AsyncMock, | ||||
|                                mock_create_task: MagicMock): | ||||
|         def reset_mocks() -> None: | ||||
|             mock__task_name.reset_mock() | ||||
|             mock__task_wrapper.reset_mock() | ||||
|             mock_create_task.reset_mock() | ||||
|  | ||||
|     @patch.object(pool, 'TaskGroupRegister') | ||||
|     @patch.object(pool.BaseTaskPool, '_check_start') | ||||
|     async def test__start_task(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__task_name: MagicMock, | ||||
|                                mock__task_wrapper: AsyncMock, mock_create_task: MagicMock): | ||||
|         mock_group_reg = set_up_mock_group_register(mock_reg_cls) | ||||
|         mock_create_task.return_value = mock_task = MagicMock() | ||||
|         mock__task_wrapper.return_value = mock_wrapped = MagicMock() | ||||
|         mock_coroutine, mock_cancel_cb, mock_end_cb = AsyncMock(), MagicMock(), MagicMock() | ||||
|         self.task_pool._counter = count = 123 | ||||
|         mock_coroutine, mock_cancel_cb, mock_end_cb = MagicMock(), MagicMock(), MagicMock() | ||||
|         self.task_pool._num_started = count = 123 | ||||
|         self.task_pool._enough_room._value = room = 123 | ||||
|  | ||||
|         def check_nothing_changed() -> None: | ||||
|             self.assertEqual(count, self.task_pool._counter) | ||||
|             self.assertNotIn(count, self.task_pool._running) | ||||
|             self.assertEqual(room, self.task_pool._enough_room._value) | ||||
|             mock__task_name.assert_not_called() | ||||
|             mock__task_wrapper.assert_not_called() | ||||
|             mock_create_task.assert_not_called() | ||||
|             reset_mocks() | ||||
|  | ||||
|         with self.assertRaises(exceptions.NotCoroutine): | ||||
|             await self.task_pool._start_task(MagicMock(), end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) | ||||
|         check_nothing_changed() | ||||
|  | ||||
|         self.task_pool._locked = True | ||||
|         ignore_closed = False | ||||
|         mock_awaitable = mock_coroutine() | ||||
|         with self.assertRaises(exceptions.PoolIsLocked): | ||||
|             await self.task_pool._start_task(mock_awaitable, ignore_closed, | ||||
|                                              end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) | ||||
|         await mock_awaitable | ||||
|         check_nothing_changed() | ||||
|  | ||||
|         ignore_closed = True | ||||
|         mock_awaitable = mock_coroutine() | ||||
|         output = await self.task_pool._start_task(mock_awaitable, ignore_closed, | ||||
|         group_name, ignore_lock = 'testgroup', True | ||||
|         output = await self.task_pool._start_task(mock_coroutine, group_name=group_name, ignore_lock=ignore_lock, | ||||
|                                                   end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) | ||||
|         await mock_awaitable | ||||
|         self.assertEqual(count, output) | ||||
|         self.assertEqual(count + 1, self.task_pool._counter) | ||||
|         self.assertEqual(mock_task, self.task_pool._running[count]) | ||||
|         mock__check_start.assert_called_once_with(awaitable=mock_coroutine, ignore_lock=ignore_lock) | ||||
|         self.assertEqual(room - 1, self.task_pool._enough_room._value) | ||||
|         self.assertEqual(mock_group_reg, self.task_pool._task_groups[group_name]) | ||||
|         mock_reg_cls.assert_called_once_with() | ||||
|         mock_group_reg.__aenter__.assert_awaited_once_with() | ||||
|         mock_group_reg.add.assert_called_once_with(count) | ||||
|         mock__task_name.assert_called_once_with(count) | ||||
|         mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) | ||||
|         mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) | ||||
|         reset_mocks() | ||||
|         self.task_pool._counter = count | ||||
|         self.task_pool._enough_room._value = room | ||||
|         del self.task_pool._running[count] | ||||
|  | ||||
|         mock_awaitable = mock_coroutine() | ||||
|         mock_create_task.side_effect = test_exception = TestException() | ||||
|         with self.assertRaises(TestException) as e: | ||||
|             await self.task_pool._start_task(mock_awaitable, ignore_closed, | ||||
|                                              end_callback=mock_end_cb, cancel_callback=mock_cancel_cb) | ||||
|             self.assertEqual(test_exception, e) | ||||
|         await mock_awaitable | ||||
|         self.assertEqual(count + 1, self.task_pool._counter) | ||||
|         self.assertNotIn(count, self.task_pool._running) | ||||
|         self.assertEqual(room, self.task_pool._enough_room._value) | ||||
|         mock__task_name.assert_called_once_with(count) | ||||
|         mock__task_wrapper.assert_called_once_with(mock_awaitable, count, mock_end_cb, mock_cancel_cb) | ||||
|         mock_create_task.assert_called_once_with(mock_wrapped, name=FOO) | ||||
|         mock__task_wrapper.assert_called_once_with(mock_coroutine, count, mock_end_cb, mock_cancel_cb) | ||||
|         mock_create_task.assert_called_once_with(coro=mock_wrapped, name=FOO) | ||||
|         self.assertEqual(mock_task, self.task_pool._tasks_running[count]) | ||||
|         mock_group_reg.__aexit__.assert_awaited_once() | ||||
|  | ||||
|     @patch.object(pool.BaseTaskPool, '_task_name', return_value=FOO) | ||||
|     def test__get_running_task(self, mock__task_name: MagicMock): | ||||
|         task_id, mock_task = 555, MagicMock() | ||||
|         self.task_pool._running[task_id] = mock_task | ||||
|         self.task_pool._tasks_running[task_id] = mock_task | ||||
|         output = self.task_pool._get_running_task(task_id) | ||||
|         self.assertEqual(mock_task, output) | ||||
|  | ||||
|         self.task_pool._cancelled[task_id] = self.task_pool._running.pop(task_id) | ||||
|         self.task_pool._tasks_cancelled[task_id] = self.task_pool._tasks_running.pop(task_id) | ||||
|         with self.assertRaises(exceptions.AlreadyCancelled): | ||||
|             self.task_pool._get_running_task(task_id) | ||||
|         mock__task_name.assert_called_once_with(task_id) | ||||
|         mock__task_name.reset_mock() | ||||
|  | ||||
|         self.task_pool._ended[task_id] = self.task_pool._cancelled.pop(task_id) | ||||
|         self.task_pool._tasks_ended[task_id] = self.task_pool._tasks_cancelled.pop(task_id) | ||||
|         with self.assertRaises(exceptions.TaskEnded): | ||||
|             self.task_pool._get_running_task(task_id) | ||||
|         mock__task_name.assert_called_once_with(task_id) | ||||
|         mock__task_name.reset_mock() | ||||
|  | ||||
|         del self.task_pool._ended[task_id] | ||||
|         del self.task_pool._tasks_ended[task_id] | ||||
|         with self.assertRaises(exceptions.InvalidTaskID): | ||||
|             self.task_pool._get_running_task(task_id) | ||||
|         mock__task_name.assert_not_called() | ||||
| @@ -344,263 +319,354 @@ class BaseTaskPoolTestCase(CommonTestCase): | ||||
|         mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) | ||||
|         mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) | ||||
|  | ||||
|     def test_cancel_all(self): | ||||
|     def test__cancel_group_meta_tasks(self): | ||||
|         mock_task1, mock_task2 = MagicMock(), MagicMock() | ||||
|         self.task_pool._running = {1: mock_task1, 2: mock_task2} | ||||
|         assert not self.task_pool._interrupt_flag.is_set() | ||||
|         self.assertIsNone(self.task_pool.cancel_all(FOO)) | ||||
|         self.assertTrue(self.task_pool._interrupt_flag.is_set()) | ||||
|         mock_task1.cancel.assert_called_once_with(msg=FOO) | ||||
|         mock_task2.cancel.assert_called_once_with(msg=FOO) | ||||
|         self.task_pool._group_meta_tasks_running[BAR] = {mock_task1, mock_task2} | ||||
|         self.assertIsNone(self.task_pool._cancel_group_meta_tasks(FOO)) | ||||
|         self.assertDictEqual({BAR: {mock_task1, mock_task2}}, self.task_pool._group_meta_tasks_running) | ||||
|         self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) | ||||
|         mock_task1.cancel.assert_not_called() | ||||
|         mock_task2.cancel.assert_not_called() | ||||
|  | ||||
|     async def test_flush(self): | ||||
|         test_exception = TestException() | ||||
|         mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) | ||||
|         self.task_pool._ended = {123: mock_ended_func()} | ||||
|         self.task_pool._cancelled = {456: mock_cancelled_func()} | ||||
|         self.task_pool._interrupt_flag.set() | ||||
|         output = await self.task_pool.flush(return_exceptions=True) | ||||
|         self.assertListEqual([FOO, test_exception], output) | ||||
|         self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) | ||||
|         self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) | ||||
|         self.assertFalse(self.task_pool._interrupt_flag.is_set()) | ||||
|         self.assertIsNone(self.task_pool._cancel_group_meta_tasks(BAR)) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) | ||||
|         self.assertSetEqual({mock_task1, mock_task2}, self.task_pool._meta_tasks_cancelled) | ||||
|         mock_task1.cancel.assert_called_once_with() | ||||
|         mock_task2.cancel.assert_called_once_with() | ||||
|  | ||||
|         self.task_pool._ended = {123: mock_ended_func()} | ||||
|         self.task_pool._cancelled = {456: mock_cancelled_func()} | ||||
|         output = await self.task_pool.flush(return_exceptions=True) | ||||
|         self.assertListEqual([FOO, test_exception], output) | ||||
|         self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) | ||||
|         self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) | ||||
|     @patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks') | ||||
|     def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock): | ||||
|         task_id = 555 | ||||
|         mock_cancel = MagicMock() | ||||
|  | ||||
|     async def test_gather(self): | ||||
|         test_exception = TestException() | ||||
|         mock_ended_func, mock_cancelled_func = AsyncMock(return_value=FOO), AsyncMock(side_effect=test_exception) | ||||
|         mock_running_func = AsyncMock(return_value=BAR) | ||||
|         mock_queue_join = AsyncMock() | ||||
|         self.task_pool._before_gathering = before_gather = [mock_queue_join()] | ||||
|         self.task_pool._ended = ended = {123: mock_ended_func()} | ||||
|         self.task_pool._cancelled = cancelled = {456: mock_cancelled_func()} | ||||
|         self.task_pool._running = running = {789: mock_running_func()} | ||||
|         self.task_pool._interrupt_flag.set() | ||||
|         def add_mock_task_to_running(_): | ||||
|             self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel) | ||||
|         # We add the fake task to the `_tasks_running` dictionary as a side effect of calling the mocked method, | ||||
|         # to verify that it is called first, before the cancellation loop starts. | ||||
|         mock__cancel_group_meta_tasks.side_effect = add_mock_task_to_running | ||||
|  | ||||
|         assert not self.task_pool._locked | ||||
|         with self.assertRaises(exceptions.PoolStillUnlocked): | ||||
|             await self.task_pool.gather() | ||||
|         self.assertDictEqual(self.task_pool._ended, ended) | ||||
|         self.assertDictEqual(self.task_pool._cancelled, cancelled) | ||||
|         self.assertDictEqual(self.task_pool._running, running) | ||||
|         self.assertListEqual(self.task_pool._before_gathering, before_gather) | ||||
|         self.assertTrue(self.task_pool._interrupt_flag.is_set()) | ||||
|         class MockRegister(set, MagicMock): | ||||
|             pass | ||||
|         self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO)) | ||||
|         mock_cancel.assert_called_once_with(msg=FOO) | ||||
|  | ||||
|         self.task_pool._locked = True | ||||
|     @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') | ||||
|     def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock): | ||||
|         self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock() | ||||
|         with self.assertRaises(exceptions.InvalidGroupName): | ||||
|             self.task_pool.cancel_group(BAR) | ||||
|         mock__cancel_and_remove_all_from_group.assert_not_called() | ||||
|         self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR)) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) | ||||
|         mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR) | ||||
|  | ||||
|         def check_assertions(output) -> None: | ||||
|             self.assertListEqual([FOO, test_exception, BAR], output) | ||||
|             self.assertDictEqual(self.task_pool._ended, EMPTY_DICT) | ||||
|             self.assertDictEqual(self.task_pool._cancelled, EMPTY_DICT) | ||||
|             self.assertDictEqual(self.task_pool._running, EMPTY_DICT) | ||||
|             self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST) | ||||
|             self.assertFalse(self.task_pool._interrupt_flag.is_set()) | ||||
|     @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') | ||||
|     def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock): | ||||
|         mock_group_reg = MagicMock() | ||||
|         self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg} | ||||
|         self.assertIsNone(self.task_pool.cancel_all('msg')) | ||||
|         mock__cancel_and_remove_all_from_group.assert_has_calls([ | ||||
|             call(BAR, mock_group_reg, msg='msg'), | ||||
|             call(FOO, mock_group_reg, msg='msg') | ||||
|         ]) | ||||
|  | ||||
|         check_assertions(await self.task_pool.gather(return_exceptions=True)) | ||||
|     def test__pop_ended_meta_tasks(self): | ||||
|         mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True) | ||||
|         self.task_pool._group_meta_tasks_running[FOO] = {mock_task, mock_done_task1} | ||||
|         mock_done_task2, mock_done_task3 = MagicMock(done=lambda: True), MagicMock(done=lambda: True) | ||||
|         self.task_pool._group_meta_tasks_running[BAR] = {mock_done_task2, mock_done_task3} | ||||
|         expected_output = {mock_done_task1, mock_done_task2, mock_done_task3} | ||||
|         output = self.task_pool._pop_ended_meta_tasks() | ||||
|         self.assertSetEqual(expected_output, output) | ||||
|         self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running) | ||||
|  | ||||
|         self.task_pool._before_gathering = [mock_queue_join()] | ||||
|         self.task_pool._ended = {123: mock_ended_func()} | ||||
|         self.task_pool._cancelled = {456: mock_cancelled_func()} | ||||
|         self.task_pool._running = {789: mock_running_func()} | ||||
|         check_assertions(await self.task_pool.gather(return_exceptions=True)) | ||||
|     @patch.object(pool.BaseTaskPool, '_pop_ended_meta_tasks') | ||||
|     async def test_flush(self, mock__pop_ended_meta_tasks: MagicMock): | ||||
|         # Meta tasks: | ||||
|         mock_ended_meta_task = AsyncMock() | ||||
|         mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()} | ||||
|         mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) | ||||
|         self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} | ||||
|         # Actual tasks: | ||||
|         mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) | ||||
|         self.task_pool._tasks_ended = {123: mock_ended_func()} | ||||
|         self.task_pool._tasks_cancelled = {456: mock_cancelled_func()} | ||||
|  | ||||
|         self.assertIsNone(await self.task_pool.flush(return_exceptions=True)) | ||||
|  | ||||
|         # Meta tasks: | ||||
|         mock__pop_ended_meta_tasks.assert_called_once_with() | ||||
|         mock_ended_meta_task.assert_awaited_once_with() | ||||
|         mock_cancelled_meta_task.assert_awaited_once_with() | ||||
|         self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) | ||||
|         # Actual tasks: | ||||
|         mock_ended_func.assert_awaited_once_with() | ||||
|         mock_cancelled_func.assert_awaited_once_with() | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) | ||||
|  | ||||
|     @patch.object(pool.BaseTaskPool, 'lock') | ||||
|     async def test_gather_and_close(self, mock_lock: MagicMock): | ||||
|         # Meta tasks: | ||||
|         mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock() | ||||
|         self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}} | ||||
|         mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) | ||||
|         self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} | ||||
|         # Actual tasks: | ||||
|         mock_running_func = AsyncMock() | ||||
|         mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception) | ||||
|         self.task_pool._tasks_ended = {123: mock_ended_func()} | ||||
|         self.task_pool._tasks_cancelled = {456: mock_cancelled_func()} | ||||
|         self.task_pool._tasks_running = {789: mock_running_func()} | ||||
|  | ||||
|         self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) | ||||
|  | ||||
|         mock_lock.assert_called_once_with() | ||||
|         # Meta tasks: | ||||
|         mock_meta_task1.assert_awaited_once_with() | ||||
|         mock_meta_task2.assert_awaited_once_with() | ||||
|         mock_cancelled_meta_task.assert_awaited_once_with() | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) | ||||
|         self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) | ||||
|         # Actual tasks: | ||||
|         mock_ended_func.assert_awaited_once_with() | ||||
|         mock_cancelled_func.assert_awaited_once_with() | ||||
|         mock_running_func.assert_awaited_once_with() | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) | ||||
|         self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) | ||||
|         self.assertTrue(self.task_pool._closed) | ||||
|  | ||||
|  | ||||
| class TaskPoolTestCase(CommonTestCase): | ||||
|     TEST_CLASS = pool.TaskPool | ||||
|     task_pool: pool.TaskPool | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_start_task') | ||||
|     async def test__apply_one(self, mock__start_task: AsyncMock): | ||||
|         mock__start_task.return_value = expected_output = 12345 | ||||
|         mock_awaitable = MagicMock() | ||||
|         mock_func = MagicMock(return_value=mock_awaitable) | ||||
|         args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         output = await self.task_pool._apply_one(mock_func, args, kwargs, end_cb, cancel_cb) | ||||
|     def test__generate_group_name(self): | ||||
|         prefix, func = 'x y z', AsyncMock(__name__=BAR) | ||||
|         base_name = f'{prefix}-{BAR}-group' | ||||
|         self.task_pool._task_groups = { | ||||
|             f'{base_name}-0': MagicMock(), | ||||
|             f'{base_name}-1': MagicMock(), | ||||
|             f'{base_name}-100': MagicMock(), | ||||
|         } | ||||
|         expected_output = f'{base_name}-2' | ||||
|         output = self.task_pool._generate_group_name(prefix, func) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_func.assert_called_once_with(*args, **kwargs) | ||||
|         mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_start_task') | ||||
|     async def test__apply_num(self, mock__start_task: AsyncMock): | ||||
|         group_name = FOO + BAR | ||||
|         mock_awaitable = object() | ||||
|         mock_func = MagicMock(return_value=mock_awaitable) | ||||
|         args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3 | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)) | ||||
|         mock_func.assert_has_calls(3 * [call(*args, **kwargs)]) | ||||
|         mock__start_task.assert_has_awaits(3 * [ | ||||
|             call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         ]) | ||||
|  | ||||
|         mock_func.reset_mock() | ||||
|         mock__start_task.reset_mock() | ||||
|  | ||||
|         output = await self.task_pool._apply_one(mock_func, args, None, end_cb, cancel_cb) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_func.assert_called_once_with(*args) | ||||
|         mock__start_task.assert_awaited_once_with(mock_awaitable, end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb)) | ||||
|         mock_func.assert_has_calls(num * [call(*args)]) | ||||
|         mock__start_task.assert_has_awaits(num * [ | ||||
|             call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         ]) | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_apply_one') | ||||
|     async def test_apply(self, mock__apply_one: AsyncMock): | ||||
|         mock__apply_one.return_value = mock_id = 67890 | ||||
|         mock_func, num = MagicMock(), 3 | ||||
|     @patch.object(pool, 'create_task') | ||||
|     @patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock()) | ||||
|     @patch.object(pool, 'TaskGroupRegister') | ||||
|     @patch.object(pool.TaskPool, '_generate_group_name') | ||||
|     @patch.object(pool.BaseTaskPool, '_check_start') | ||||
|     def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, | ||||
|                    mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): | ||||
|         mock__generate_group_name.return_value = generated_name = 'name 123' | ||||
|         mock_group_reg = set_up_mock_group_register(mock_reg_cls) | ||||
|         mock__apply_num.return_value = mock_apply_coroutine = object() | ||||
|         mock_create_task.return_value = fake_task = object() | ||||
|         mock_func, num, group_name = MagicMock(), 3, FOO + BAR | ||||
|         args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         expected_output = num * [mock_id] | ||||
|         output = await self.task_pool.apply(mock_func, args, kwargs, num, end_cb, cancel_cb) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock__apply_one.assert_has_awaits(num * [call(mock_func, args, kwargs, end_cb, cancel_cb)]) | ||||
|         self.task_pool._task_groups = {} | ||||
|  | ||||
|     async def test__queue_producer(self): | ||||
|         mock_put = AsyncMock() | ||||
|         mock_q = MagicMock(put=mock_put) | ||||
|         args = (FOO, BAR, 123) | ||||
|         assert not self.task_pool._interrupt_flag.is_set() | ||||
|         self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) | ||||
|         mock_put.assert_has_awaits([call(arg) for arg in args]) | ||||
|         mock_put.reset_mock() | ||||
|         self.task_pool._interrupt_flag.set() | ||||
|         self.assertIsNone(await self.task_pool._queue_producer(mock_q, args)) | ||||
|         mock_put.assert_not_awaited() | ||||
|         def check_assertions(_group_name, _output): | ||||
|             self.assertEqual(_group_name, _output) | ||||
|             mock__check_start.assert_called_once_with(function=mock_func) | ||||
|             self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name]) | ||||
|             mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, | ||||
|                                                     end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|             mock_create_task.assert_called_once_with(mock_apply_coroutine) | ||||
|             self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) | ||||
|  | ||||
|     @patch.object(pool, 'partial') | ||||
|     @patch.object(pool, 'star_function') | ||||
|     @patch.object(pool.TaskPool, '_start_task') | ||||
|     async def test__queue_consumer(self, mock__start_task: AsyncMock, mock_star_function: MagicMock, | ||||
|                                    mock_partial: MagicMock): | ||||
|         mock_partial.return_value = queue_callback = 'not really' | ||||
|         mock_star_function.return_value = awaitable = 'totally an awaitable' | ||||
|         q, arg = Queue(), 420.69 | ||||
|         q.put_nowait(arg) | ||||
|         mock_func, stars = MagicMock(), 3 | ||||
|         mock_flag, end_cb, cancel_cb = MagicMock(), MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb)) | ||||
|         self.assertTrue(q.empty()) | ||||
|         mock__start_task.assert_awaited_once_with(awaitable, ignore_lock=True, | ||||
|                                                   end_callback=queue_callback, cancel_callback=cancel_cb) | ||||
|         mock_star_function.assert_called_once_with(mock_func, arg, arg_stars=stars) | ||||
|         mock_partial.assert_called_once_with(pool.TaskPool._queue_callback, self.task_pool, | ||||
|                                              q=q, first_batch_started=mock_flag, func=mock_func, arg_stars=stars, | ||||
|                                              end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__start_task.reset_mock() | ||||
|         mock_star_function.reset_mock() | ||||
|         mock_partial.reset_mock() | ||||
|         output = self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb) | ||||
|         check_assertions(group_name, output) | ||||
|         mock__generate_group_name.assert_not_called() | ||||
|  | ||||
|         self.assertIsNone(await self.task_pool._queue_consumer(q, mock_flag, mock_func, stars, end_cb, cancel_cb)) | ||||
|         self.assertTrue(q.empty()) | ||||
|         mock__start_task.assert_not_awaited() | ||||
|         mock_star_function.assert_not_called() | ||||
|         mock_partial.assert_not_called() | ||||
|  | ||||
|     @patch.object(pool, 'execute_optional') | ||||
|     @patch.object(pool.TaskPool, '_queue_consumer') | ||||
|     async def test__queue_callback(self, mock__queue_consumer: AsyncMock, mock_execute_optional: AsyncMock): | ||||
|         task_id, mock_q = 420, MagicMock() | ||||
|         mock_func, stars = MagicMock(), 3 | ||||
|         mock_wait = AsyncMock() | ||||
|         mock_flag = MagicMock(wait=mock_wait) | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool._queue_callback(task_id, mock_q, mock_flag, mock_func, stars, | ||||
|                                                                end_callback=end_cb, cancel_callback=cancel_cb)) | ||||
|         mock_wait.assert_awaited_once_with() | ||||
|         mock__queue_consumer.assert_awaited_once_with(mock_q, mock_flag, mock_func, stars, | ||||
|                                                       end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock_execute_optional.assert_awaited_once_with(end_cb, args=(task_id,)) | ||||
|  | ||||
|     @patch.object(pool, 'iter') | ||||
|     @patch.object(pool, 'create_task') | ||||
|     @patch.object(pool, 'join_queue', new_callable=MagicMock) | ||||
|     @patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock) | ||||
|     async def test__set_up_args_queue(self, mock__queue_producer: MagicMock, mock_join_queue: MagicMock, | ||||
|                                       mock_create_task: MagicMock, mock_iter: MagicMock): | ||||
|         args, num_tasks = (FOO, BAR, 1, 2, 3), 2 | ||||
|         mock_join_queue.return_value = mock_join = 'awaitable' | ||||
|         mock_iter.return_value = args_iter = iter(args) | ||||
|         mock__queue_producer.return_value = mock_producer_coro = 'very awaitable' | ||||
|         output_q = self.task_pool._set_up_args_queue(args, num_tasks) | ||||
|         self.assertIsInstance(output_q, Queue) | ||||
|         self.assertEqual(num_tasks, output_q.qsize()) | ||||
|         for arg in args[:num_tasks]: | ||||
|             self.assertEqual(arg, output_q.get_nowait()) | ||||
|         self.assertTrue(output_q.empty()) | ||||
|         for arg in args[num_tasks:]: | ||||
|             self.assertEqual(arg, next(args_iter)) | ||||
|         with self.assertRaises(StopIteration): | ||||
|             next(args_iter) | ||||
|         self.assertListEqual([mock_join], self.task_pool._before_gathering) | ||||
|         mock_join_queue.assert_called_once_with(output_q) | ||||
|         mock__queue_producer.assert_called_once_with(output_q, args_iter) | ||||
|         mock_create_task.assert_called_once_with(mock_producer_coro) | ||||
|  | ||||
|         self.task_pool._before_gathering.clear() | ||||
|         mock_join_queue.reset_mock() | ||||
|         mock__queue_producer.reset_mock() | ||||
|         mock__check_start.reset_mock() | ||||
|         self.task_pool._task_groups.clear() | ||||
|         mock__apply_num.reset_mock() | ||||
|         mock_create_task.reset_mock() | ||||
|  | ||||
|         num_tasks = 6 | ||||
|         mock_iter.return_value = args_iter = iter(args) | ||||
|         output_q = self.task_pool._set_up_args_queue(args, num_tasks) | ||||
|         self.assertIsInstance(output_q, Queue) | ||||
|         self.assertEqual(len(args), output_q.qsize()) | ||||
|         for arg in args: | ||||
|             self.assertEqual(arg, output_q.get_nowait()) | ||||
|         self.assertTrue(output_q.empty()) | ||||
|         with self.assertRaises(StopIteration): | ||||
|             next(args_iter) | ||||
|         self.assertListEqual([mock_join], self.task_pool._before_gathering) | ||||
|         mock_join_queue.assert_called_once_with(output_q) | ||||
|         mock__queue_producer.assert_not_called() | ||||
|         mock_create_task.assert_not_called() | ||||
|         output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) | ||||
|         check_assertions(generated_name, output) | ||||
|         mock__generate_group_name.assert_called_once_with('apply', mock_func) | ||||
|  | ||||
|     @patch.object(pool, 'Event') | ||||
|     @patch.object(pool.TaskPool, '_queue_consumer') | ||||
|     @patch.object(pool.TaskPool, '_set_up_args_queue') | ||||
|     async def test__map(self, mock__set_up_args_queue: MagicMock, mock__queue_consumer: AsyncMock, | ||||
|                         mock_event_cls: MagicMock): | ||||
|         qsize = 4 | ||||
|         mock__set_up_args_queue.return_value = mock_q = MagicMock(qsize=MagicMock(return_value=qsize)) | ||||
|         mock_flag_set = MagicMock() | ||||
|         mock_event_cls.return_value = mock_flag = MagicMock(set=mock_flag_set) | ||||
|     @patch.object(pool, 'execute_optional') | ||||
|     async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock): | ||||
|         semaphore, mock_end_cb = Semaphore(1), MagicMock() | ||||
|         wrapped = pool.TaskPool._get_map_end_callback(semaphore, mock_end_cb) | ||||
|         task_id = 1234 | ||||
|         await wrapped(task_id) | ||||
|         self.assertEqual(2, semaphore._value) | ||||
|         mock_execute_optional.assert_awaited_once_with(mock_end_cb, args=(task_id,)) | ||||
|  | ||||
|         mock_func, stars = MagicMock(), 3 | ||||
|         args_iter, group_size = (FOO, BAR, 1, 2, 3), 2 | ||||
|     @patch.object(pool, 'star_function') | ||||
|     @patch.object(pool.TaskPool, '_start_task') | ||||
|     @patch.object(pool.TaskPool, '_get_map_end_callback') | ||||
|     @patch.object(pool, 'Semaphore') | ||||
|     async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock, | ||||
|                                    mock__start_task: AsyncMock, mock_star_function: MagicMock): | ||||
|         n = 2 | ||||
|         mock_semaphore_cls.return_value = semaphore = Semaphore(n) | ||||
|         mock__get_map_end_callback.return_value = map_cb = MagicMock() | ||||
|         awaitable = 'totally an awaitable' | ||||
|         mock_star_function.side_effect = [awaitable, Exception(), awaitable] | ||||
|         arg1, arg2, bad = 123456789, 'function argument', None | ||||
|         args = [arg1, bad, arg2] | ||||
|         group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3 | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) | ||||
|         # We expect the semaphore to be acquired 2 times, then be released once after the exception occurs, then | ||||
|         # acquired once more is reached. Since we initialized it with a value of 2, we expect it be locked. | ||||
|         self.assertTrue(semaphore.locked()) | ||||
|         mock_semaphore_cls.assert_called_once_with(n) | ||||
|         mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) | ||||
|         mock__start_task.assert_has_awaits(2 * [ | ||||
|             call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) | ||||
|         ]) | ||||
|         mock_star_function.assert_has_calls([ | ||||
|             call(mock_func, arg1, arg_stars=stars), | ||||
|             call(mock_func, bad, arg_stars=stars), | ||||
|             call(mock_func, arg2, arg_stars=stars) | ||||
|         ]) | ||||
|  | ||||
|         mock_semaphore_cls.reset_mock() | ||||
|         mock__get_map_end_callback.reset_mock() | ||||
|         mock__start_task.reset_mock() | ||||
|         mock_star_function.reset_mock() | ||||
|  | ||||
|         # With a CancelledError thrown while starting a task: | ||||
|         mock_semaphore_cls.return_value = semaphore = Semaphore(1) | ||||
|         mock_star_function.side_effect = CancelledError() | ||||
|         self.assertIsNone(await self.task_pool._arg_consumer(group_name, n, mock_func, args, stars, end_cb, cancel_cb)) | ||||
|         self.assertFalse(semaphore.locked()) | ||||
|         mock_semaphore_cls.assert_called_once_with(n) | ||||
|         mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) | ||||
|         mock__start_task.assert_not_called() | ||||
|         mock_star_function.assert_called_once_with(mock_func, arg1, arg_stars=stars) | ||||
|  | ||||
|     @patch.object(pool, 'create_task') | ||||
|     @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock) | ||||
|     @patch.object(pool, 'TaskGroupRegister') | ||||
|     @patch.object(pool.BaseTaskPool, '_check_start') | ||||
|     def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock, | ||||
|                   mock_create_task: MagicMock): | ||||
|         mock_group_reg = set_up_mock_group_register(mock_reg_cls) | ||||
|         mock__arg_consumer.return_value = fake_consumer = object() | ||||
|         mock_create_task.return_value = fake_task = object() | ||||
|  | ||||
|         group_name, n = 'onetwothree', 0 | ||||
|         func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3 | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|  | ||||
|         self.task_pool._locked = False | ||||
|         with self.assertRaises(exceptions.PoolIsLocked): | ||||
|             await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb) | ||||
|         mock__set_up_args_queue.assert_not_called() | ||||
|         mock__queue_consumer.assert_not_awaited() | ||||
|         mock_flag_set.assert_not_called() | ||||
|         with self.assertRaises(ValueError): | ||||
|             self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) | ||||
|         mock__check_start.assert_called_once_with(function=func) | ||||
|  | ||||
|         self.task_pool._locked = True | ||||
|         self.assertIsNone(await self.task_pool._map(mock_func, args_iter, stars, group_size, end_cb, cancel_cb)) | ||||
|         mock__set_up_args_queue.assert_called_once_with(args_iter, group_size) | ||||
|         mock__queue_consumer.assert_has_awaits(qsize * [call(mock_q, mock_flag, mock_func, arg_stars=stars, | ||||
|                                                              end_callback=end_cb, cancel_callback=cancel_cb)]) | ||||
|         mock_flag_set.assert_called_once_with() | ||||
|         mock__check_start.reset_mock() | ||||
|  | ||||
|         n = 1234 | ||||
|         self.task_pool._task_groups = {group_name: MagicMock()} | ||||
|  | ||||
|         with self.assertRaises(exceptions.InvalidGroupName): | ||||
|             self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb) | ||||
|         mock__check_start.assert_called_once_with(function=func) | ||||
|  | ||||
|         mock__check_start.reset_mock() | ||||
|  | ||||
|         self.task_pool._task_groups.clear() | ||||
|  | ||||
|         self.assertIsNone(self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)) | ||||
|         mock__check_start.assert_called_once_with(function=func) | ||||
|         mock_reg_cls.assert_called_once_with() | ||||
|         self.task_pool._task_groups[group_name] = mock_group_reg | ||||
|         mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars, | ||||
|                                                    end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock_create_task.assert_called_once_with(fake_consumer) | ||||
|         self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name]) | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_map') | ||||
|     async def test_map(self, mock__map: AsyncMock): | ||||
|     @patch.object(pool.TaskPool, '_generate_group_name') | ||||
|     def test_map(self, mock__generate_group_name: MagicMock, mock__map: MagicMock): | ||||
|         mock__generate_group_name.return_value = generated_name = 'name 1 2 3' | ||||
|         mock_func = MagicMock() | ||||
|         arg_iter, group_size = (FOO, BAR, 1, 2, 3), 2 | ||||
|         arg_iter, num_concurrent, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool.map(mock_func, arg_iter, group_size, end_cb, cancel_cb)) | ||||
|         mock__map.assert_awaited_once_with(mock_func, arg_iter, arg_stars=0, group_size=group_size, | ||||
|                                            end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         output = self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb) | ||||
|         self.assertEqual(group_name, output) | ||||
|         mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, arg_iter, 0, | ||||
|                                           end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__generate_group_name.assert_not_called() | ||||
|  | ||||
|         mock__map.reset_mock() | ||||
|         output = self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb) | ||||
|         self.assertEqual(generated_name, output) | ||||
|         mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0, | ||||
|                                           end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__generate_group_name.assert_called_once_with('map', mock_func) | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_map') | ||||
|     async def test_starmap(self, mock__map: AsyncMock): | ||||
|     @patch.object(pool.TaskPool, '_generate_group_name') | ||||
|     def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock): | ||||
|         mock__generate_group_name.return_value = generated_name = 'name 1 2 3' | ||||
|         mock_func = MagicMock() | ||||
|         args_iter, group_size = ([FOO], [BAR]), 2 | ||||
|         args_iter, num_concurrent, group_name = ([FOO], [BAR]), 2, FOO + BAR | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool.starmap(mock_func, args_iter, group_size, end_cb, cancel_cb)) | ||||
|         mock__map.assert_awaited_once_with(mock_func, args_iter, arg_stars=1, group_size=group_size, | ||||
|                                            end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb) | ||||
|         self.assertEqual(group_name, output) | ||||
|         mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, args_iter, 1, | ||||
|                                           end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__generate_group_name.assert_not_called() | ||||
|  | ||||
|         mock__map.reset_mock() | ||||
|         output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb) | ||||
|         self.assertEqual(generated_name, output) | ||||
|         mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, args_iter, 1, | ||||
|                                           end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__generate_group_name.assert_called_once_with('starmap', mock_func) | ||||
|  | ||||
|     @patch.object(pool.TaskPool, '_map') | ||||
|     async def test_doublestarmap(self, mock__map: AsyncMock): | ||||
|     @patch.object(pool.TaskPool, '_generate_group_name') | ||||
|     async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock): | ||||
|         mock__generate_group_name.return_value = generated_name = 'name 1 2 3' | ||||
|         mock_func = MagicMock() | ||||
|         kwargs_iter, group_size = [{'a': FOO}, {'a': BAR}], 2 | ||||
|         kw_iter, num_concurrent, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR | ||||
|         end_cb, cancel_cb = MagicMock(), MagicMock() | ||||
|         self.assertIsNone(await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, end_cb, cancel_cb)) | ||||
|         mock__map.assert_awaited_once_with(mock_func, kwargs_iter, arg_stars=2, group_size=group_size, | ||||
|                                            end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb) | ||||
|         self.assertEqual(group_name, output) | ||||
|         mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, kw_iter, 2, | ||||
|                                           end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__generate_group_name.assert_not_called() | ||||
|  | ||||
|         mock__map.reset_mock() | ||||
|         output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb) | ||||
|         self.assertEqual(generated_name, output) | ||||
|         mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2, | ||||
|                                           end_callback=end_cb, cancel_callback=cancel_cb) | ||||
|         mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func) | ||||
|  | ||||
|  | ||||
| class SimpleTaskPoolTestCase(CommonTestCase): | ||||
| @@ -645,29 +711,48 @@ class SimpleTaskPoolTestCase(CommonTestCase): | ||||
|         self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name) | ||||
|  | ||||
|     @patch.object(pool.SimpleTaskPool, '_start_task') | ||||
|     async def test__start_one(self, mock__start_task: AsyncMock): | ||||
|         mock__start_task.return_value = expected_output = 99 | ||||
|         self.task_pool._func = MagicMock(return_value=BAR) | ||||
|         output = await self.task_pool._start_one() | ||||
|         self.assertEqual(expected_output, output) | ||||
|         self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs) | ||||
|         mock__start_task.assert_awaited_once_with(BAR, end_callback=self.task_pool._end_callback, | ||||
|                                                   cancel_callback=self.task_pool._cancel_callback) | ||||
|     async def test__start_num(self, mock__start_task: AsyncMock): | ||||
|         fake_coroutine = object() | ||||
|         self.task_pool._func = MagicMock(return_value=fake_coroutine) | ||||
|         num = 3 | ||||
|         group_name = FOO + BAR + 'abc' | ||||
|         self.assertIsNone(await self.task_pool._start_num(num, group_name)) | ||||
|         self.task_pool._func.assert_has_calls(num * [ | ||||
|             call(*self.task_pool._args, **self.task_pool._kwargs) | ||||
|         ]) | ||||
|         mock__start_task.assert_has_awaits(num * [ | ||||
|             call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback, | ||||
|                  cancel_callback=self.task_pool._cancel_callback) | ||||
|         ]) | ||||
|  | ||||
|     @patch.object(pool.SimpleTaskPool, '_start_one') | ||||
|     async def test_start(self, mock__start_one: AsyncMock): | ||||
|         mock__start_one.return_value = FOO | ||||
|     @patch.object(pool, 'create_task') | ||||
|     @patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock()) | ||||
|     @patch.object(pool, 'TaskGroupRegister') | ||||
|     @patch.object(pool.BaseTaskPool, '_check_start') | ||||
|     def test_start(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__start_num: AsyncMock, | ||||
|                    mock_create_task: MagicMock): | ||||
|         mock_group_reg = set_up_mock_group_register(mock_reg_cls) | ||||
|         mock__start_num.return_value = mock_start_num_coroutine = object() | ||||
|         mock_create_task.return_value = fake_task = object() | ||||
|         self.task_pool._task_groups = {} | ||||
|         self.task_pool._group_meta_tasks_running = {} | ||||
|         num = 5 | ||||
|         output = await self.task_pool.start(num) | ||||
|         expected_output = num * [FOO] | ||||
|         self.assertListEqual(expected_output, output) | ||||
|         mock__start_one.assert_has_awaits(num * [call()]) | ||||
|         self.task_pool._start_calls = 42 | ||||
|         expected_group_name = 'start-group-42' | ||||
|         output = self.task_pool.start(num) | ||||
|         self.assertEqual(expected_group_name, output) | ||||
|         mock__check_start.assert_called_once_with(function=self.TEST_POOL_FUNC) | ||||
|         self.assertEqual(43, self.task_pool._start_calls) | ||||
|         self.assertEqual(mock_group_reg, self.task_pool._task_groups[expected_group_name]) | ||||
|         mock__start_num.assert_called_once_with(num, expected_group_name) | ||||
|         mock_create_task.assert_called_once_with(mock_start_num_coroutine) | ||||
|         self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[expected_group_name]) | ||||
|  | ||||
|     @patch.object(pool.SimpleTaskPool, 'cancel') | ||||
|     def test_stop(self, mock_cancel: MagicMock): | ||||
|         num = 2 | ||||
|         id1, id2, id3 = 5, 6, 7 | ||||
|         self.task_pool._running = {id1: FOO, id2: BAR, id3: FOO + BAR} | ||||
|         self.task_pool._tasks_running = {id1: FOO, id2: BAR, id3: FOO + BAR} | ||||
|         output = self.task_pool.stop(num) | ||||
|         expected_output = [id3, id2] | ||||
|         self.assertEqual(expected_output, output) | ||||
| @@ -689,3 +774,10 @@ class SimpleTaskPoolTestCase(CommonTestCase): | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_num_running.assert_called_once_with() | ||||
|         mock_stop.assert_called_once_with(num) | ||||
|  | ||||
|  | ||||
| def set_up_mock_group_register(mock_reg_cls: MagicMock) -> MagicMock: | ||||
|     mock_grp_aenter, mock_grp_aexit, mock_grp_add = AsyncMock(), AsyncMock(), MagicMock() | ||||
|     mock_reg_cls.return_value = mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit, | ||||
|                                                            add=mock_grp_add) | ||||
|     return mock_group_reg | ||||
|   | ||||
							
								
								
									
										43
									
								
								tests/test_queue_context.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								tests/test_queue_context.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.queue_context` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import MagicMock, patch | ||||
|  | ||||
| from asyncio_taskpool.queue_context import Queue | ||||
|  | ||||
|  | ||||
| class QueueTestCase(IsolatedAsyncioTestCase): | ||||
|     def test_item_processed(self): | ||||
|         queue = Queue() | ||||
|         queue._unfinished_tasks = 1000 | ||||
|         queue.item_processed() | ||||
|         self.assertEqual(999, queue._unfinished_tasks) | ||||
|  | ||||
|     @patch.object(Queue, 'item_processed') | ||||
|     async def test_contextmanager(self, mock_item_processed: MagicMock): | ||||
|         queue = Queue() | ||||
|         item = 'foo' | ||||
|         queue.put_nowait(item) | ||||
|         async with queue as item_from_queue: | ||||
|             self.assertEqual(item, item_from_queue) | ||||
|             mock_item_processed.assert_not_called() | ||||
|         mock_item_processed.assert_called_once_with() | ||||
| @@ -1,324 +0,0 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.session` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import json | ||||
| from argparse import ArgumentError, Namespace | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import AsyncMock, MagicMock, patch, call | ||||
|  | ||||
| from asyncio_taskpool import session | ||||
| from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, SESSION_WRITER | ||||
| from asyncio_taskpool.exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass | ||||
| from asyncio_taskpool.pool import BaseTaskPool, TaskPool, SimpleTaskPool | ||||
|  | ||||
|  | ||||
| FOO, BAR = 'foo', 'bar' | ||||
|  | ||||
|  | ||||
| class ControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|     log_lvl: int | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls) -> None: | ||||
|         cls.log_lvl = session.log.level | ||||
|         session.log.setLevel(999) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls) -> None: | ||||
|         session.log.setLevel(cls.log_lvl) | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.mock_pool = MagicMock(spec=SimpleTaskPool(AsyncMock())) | ||||
|         self.mock_client_class_name = FOO + BAR | ||||
|         self.mock_server = MagicMock(pool=self.mock_pool, | ||||
|                                      client_class_name=self.mock_client_class_name) | ||||
|         self.mock_reader = MagicMock() | ||||
|         self.mock_writer = MagicMock() | ||||
|         self.session = session.ControlSession(self.mock_server, self.mock_reader, self.mock_writer) | ||||
|  | ||||
|     def test_init(self): | ||||
|         self.assertEqual(self.mock_server, self.session._control_server) | ||||
|         self.assertEqual(self.mock_pool, self.session._pool) | ||||
|         self.assertEqual(self.mock_client_class_name, self.session._client_class_name) | ||||
|         self.assertEqual(self.mock_reader, self.session._reader) | ||||
|         self.assertEqual(self.mock_writer, self.session._writer) | ||||
|         self.assertIsNone(self.session._parser) | ||||
|         self.assertIsNone(self.session._subparsers) | ||||
|  | ||||
|     def test__add_command(self): | ||||
|         expected_output = 123456 | ||||
|         mock_add_parser = MagicMock(return_value=expected_output) | ||||
|         self.session._subparsers = MagicMock(add_parser=mock_add_parser) | ||||
|         self.session._parser = MagicMock() | ||||
|         name, prog, short_help, long_help = 'abc', None, 'short123', None | ||||
|         kwargs = {'x': 1, 'y': 2} | ||||
|         output = self.session._add_command(name, prog, short_help, long_help, **kwargs) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_add_parser.assert_called_once_with(name, prog=name, help=short_help, description=short_help, | ||||
|                                                 parent=self.session._parser, **kwargs) | ||||
|  | ||||
|         mock_add_parser.reset_mock() | ||||
|  | ||||
|         prog, long_help = 'ffffff', 'so long, wow' | ||||
|         output = self.session._add_command(name, prog, short_help, long_help, **kwargs) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_add_parser.assert_called_once_with(name, prog=prog, help=short_help, description=long_help, | ||||
|                                                 parent=self.session._parser, **kwargs) | ||||
|  | ||||
|         mock_add_parser.reset_mock() | ||||
|  | ||||
|         short_help = None | ||||
|         output = self.session._add_command(name, prog, short_help, long_help, **kwargs) | ||||
|         self.assertEqual(expected_output, output) | ||||
|         mock_add_parser.assert_called_once_with(name, prog=prog, help=long_help, description=long_help, | ||||
|                                                 parent=self.session._parser, **kwargs) | ||||
|  | ||||
|     @patch.object(session, 'get_first_doc_line') | ||||
|     @patch.object(session.ControlSession, '_add_command') | ||||
|     def test__adding_commands(self, mock__add_command: MagicMock, mock_get_first_doc_line: MagicMock): | ||||
|         self.assertIsNone(self.session._add_base_commands()) | ||||
|         mock__add_command.assert_called() | ||||
|         mock_get_first_doc_line.assert_called() | ||||
|  | ||||
|         mock__add_command.reset_mock() | ||||
|         mock_get_first_doc_line.reset_mock() | ||||
|  | ||||
|         self.assertIsNone(self.session._add_simple_commands()) | ||||
|         mock__add_command.assert_called() | ||||
|         mock_get_first_doc_line.assert_called() | ||||
|  | ||||
|         with self.assertRaises(NotImplementedError): | ||||
|             self.session._add_advanced_commands() | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_add_simple_commands') | ||||
|     @patch.object(session.ControlSession, '_add_advanced_commands') | ||||
|     @patch.object(session.ControlSession, '_add_base_commands') | ||||
|     @patch.object(session, 'CommandParser') | ||||
|     def test__init_parser(self, mock_command_parser_cls: MagicMock, mock__add_base_commands: MagicMock, | ||||
|                           mock__add_advanced_commands: MagicMock, mock__add_simple_commands: MagicMock): | ||||
|         mock_command_parser_cls.return_value = mock_parser = MagicMock() | ||||
|         self.session._pool = TaskPool() | ||||
|         width = 1234 | ||||
|         expected_parser_kwargs = { | ||||
|             'prog': '', | ||||
|             SESSION_WRITER: self.mock_writer, | ||||
|             CLIENT_INFO.TERMINAL_WIDTH: width, | ||||
|         } | ||||
|         self.assertIsNone(self.session._init_parser(width)) | ||||
|         mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) | ||||
|         mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) | ||||
|         mock__add_base_commands.assert_called_once_with() | ||||
|         mock__add_advanced_commands.assert_called_once_with() | ||||
|         mock__add_simple_commands.assert_not_called() | ||||
|  | ||||
|         mock_command_parser_cls.reset_mock() | ||||
|         mock_parser.add_subparsers.reset_mock() | ||||
|         mock__add_base_commands.reset_mock() | ||||
|         mock__add_advanced_commands.reset_mock() | ||||
|         mock__add_simple_commands.reset_mock() | ||||
|  | ||||
|         async def fake_coroutine(): pass | ||||
|  | ||||
|         self.session._pool = SimpleTaskPool(fake_coroutine) | ||||
|         self.assertIsNone(self.session._init_parser(width)) | ||||
|         mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) | ||||
|         mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) | ||||
|         mock__add_base_commands.assert_called_once_with() | ||||
|         mock__add_advanced_commands.assert_not_called() | ||||
|         mock__add_simple_commands.assert_called_once_with() | ||||
|  | ||||
|         mock_command_parser_cls.reset_mock() | ||||
|         mock_parser.add_subparsers.reset_mock() | ||||
|         mock__add_base_commands.reset_mock() | ||||
|         mock__add_advanced_commands.reset_mock() | ||||
|         mock__add_simple_commands.reset_mock() | ||||
|  | ||||
|         class FakeTaskPool(BaseTaskPool): | ||||
|             pass | ||||
|  | ||||
|         self.session._pool = FakeTaskPool() | ||||
|         with self.assertRaises(UnknownTaskPoolClass): | ||||
|             self.session._init_parser(width) | ||||
|         mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) | ||||
|         mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) | ||||
|         mock__add_base_commands.assert_called_once_with() | ||||
|         mock__add_advanced_commands.assert_not_called() | ||||
|         mock__add_simple_commands.assert_not_called() | ||||
|  | ||||
|         mock_command_parser_cls.reset_mock() | ||||
|         mock_parser.add_subparsers.reset_mock() | ||||
|         mock__add_base_commands.reset_mock() | ||||
|         mock__add_advanced_commands.reset_mock() | ||||
|         mock__add_simple_commands.reset_mock() | ||||
|  | ||||
|         self.session._pool = MagicMock() | ||||
|         with self.assertRaises(NotATaskPool): | ||||
|             self.session._init_parser(width) | ||||
|         mock_command_parser_cls.assert_called_once_with(**expected_parser_kwargs) | ||||
|         mock_parser.add_subparsers.assert_called_once_with(title="Commands", dest=CMD.CMD) | ||||
|         mock__add_base_commands.assert_called_once_with() | ||||
|         mock__add_advanced_commands.assert_not_called() | ||||
|         mock__add_simple_commands.assert_not_called() | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_init_parser') | ||||
|     async def test_client_handshake(self, mock__init_parser: MagicMock): | ||||
|         width = 5678 | ||||
|         msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + '  ' | ||||
|         mock_read = AsyncMock(return_value=msg.encode()) | ||||
|         self.mock_reader.read = mock_read | ||||
|         self.mock_writer.drain = AsyncMock() | ||||
|         self.assertIsNone(await self.session.client_handshake()) | ||||
|         mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) | ||||
|         mock__init_parser.assert_called_once_with(width) | ||||
|         self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode()) | ||||
|         self.mock_writer.drain.assert_awaited_once_with() | ||||
|  | ||||
|     @patch.object(session, 'return_or_exception') | ||||
|     async def test__write_function_output(self, mock_return_or_exception: MagicMock): | ||||
|         self.mock_writer.write = MagicMock() | ||||
|         mock_return_or_exception.return_value = None | ||||
|         func, args, kwargs = MagicMock(), (1, 2, 3), {'a': 'A', 'b': 'B'} | ||||
|         self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs)) | ||||
|         mock_return_or_exception.assert_called_once_with(func, *args, **kwargs) | ||||
|         self.mock_writer.write.assert_called_once_with(b"ok") | ||||
|  | ||||
|         mock_return_or_exception.reset_mock() | ||||
|         self.mock_writer.write.reset_mock() | ||||
|  | ||||
|         mock_return_or_exception.return_value = output = MagicMock() | ||||
|         self.assertIsNone(await self.session._write_function_output(func, *args, **kwargs)) | ||||
|         mock_return_or_exception.assert_called_once_with(func, *args, **kwargs) | ||||
|         self.mock_writer.write.assert_called_once_with(str(output).encode()) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_name(self, mock__write_function_output: AsyncMock): | ||||
|         self.assertIsNone(await self.session._cmd_name()) | ||||
|         mock__write_function_output.assert_awaited_once_with(self.mock_pool.__class__.__str__, self.session._pool) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_pool_size(self, mock__write_function_output: AsyncMock): | ||||
|         num = 12345 | ||||
|         kwargs = {session.NUM: num, FOO: BAR} | ||||
|         self.assertIsNone(await self.session._cmd_pool_size(**kwargs)) | ||||
|         mock__write_function_output.assert_awaited_once_with( | ||||
|             self.mock_pool.__class__.pool_size.fset, self.session._pool, num | ||||
|         ) | ||||
|  | ||||
|         mock__write_function_output.reset_mock() | ||||
|  | ||||
|         kwargs.pop(session.NUM) | ||||
|         self.assertIsNone(await self.session._cmd_pool_size(**kwargs)) | ||||
|         mock__write_function_output.assert_awaited_once_with( | ||||
|             self.mock_pool.__class__.pool_size.fget, self.session._pool | ||||
|         ) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_num_running(self, mock__write_function_output: AsyncMock): | ||||
|         self.assertIsNone(await self.session._cmd_num_running()) | ||||
|         mock__write_function_output.assert_awaited_once_with( | ||||
|             self.mock_pool.__class__.num_running.fget, self.session._pool | ||||
|         ) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_start(self, mock__write_function_output: AsyncMock): | ||||
|         num = 12345 | ||||
|         kwargs = {session.NUM: num, FOO: BAR} | ||||
|         self.assertIsNone(await self.session._cmd_start(**kwargs)) | ||||
|         mock__write_function_output.assert_awaited_once_with(self.mock_pool.start, num) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_stop(self, mock__write_function_output: AsyncMock): | ||||
|         num = 12345 | ||||
|         kwargs = {session.NUM: num, FOO: BAR} | ||||
|         self.assertIsNone(await self.session._cmd_stop(**kwargs)) | ||||
|         mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop, num) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_stop_all(self, mock__write_function_output: AsyncMock): | ||||
|         self.assertIsNone(await self.session._cmd_stop_all()) | ||||
|         mock__write_function_output.assert_awaited_once_with(self.mock_pool.stop_all) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_write_function_output') | ||||
|     async def test__cmd_func_name(self, mock__write_function_output: AsyncMock): | ||||
|         self.assertIsNone(await self.session._cmd_func_name()) | ||||
|         mock__write_function_output.assert_awaited_once_with( | ||||
|             self.mock_pool.__class__.func_name.fget, self.session._pool | ||||
|         ) | ||||
|  | ||||
|     async def test__execute_command(self): | ||||
|         mock_method = AsyncMock() | ||||
|         cmd = 'this-is-a-test' | ||||
|         setattr(self.session, '_cmd_' + cmd.replace('-', '_'), mock_method) | ||||
|         kwargs = {FOO: BAR, 'hello': 'python'} | ||||
|         self.assertIsNone(await self.session._execute_command(**{CMD.CMD: cmd}, **kwargs)) | ||||
|         mock_method.assert_awaited_once_with(**kwargs) | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_execute_command') | ||||
|     async def test__parse_command(self, mock__execute_command: AsyncMock): | ||||
|         msg = 'asdf asd as a' | ||||
|         kwargs = {FOO: BAR, 'hello': 'python'} | ||||
|         mock_parse_args = MagicMock(return_value=Namespace(**kwargs)) | ||||
|         self.session._parser = MagicMock(parse_args=mock_parse_args) | ||||
|         self.mock_writer.write = MagicMock() | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_not_called() | ||||
|         mock__execute_command.assert_awaited_once_with(**kwargs) | ||||
|  | ||||
|         mock__execute_command.reset_mock() | ||||
|         mock_parse_args.reset_mock() | ||||
|  | ||||
|         mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_called_once_with(str(exc).encode()) | ||||
|         mock__execute_command.assert_not_awaited() | ||||
|  | ||||
|         self.mock_writer.write.reset_mock() | ||||
|         mock_parse_args.reset_mock() | ||||
|  | ||||
|         mock_parse_args.side_effect = HelpRequested() | ||||
|         self.assertIsNone(await self.session._parse_command(msg)) | ||||
|         mock_parse_args.assert_called_once_with(msg.split(' ')) | ||||
|         self.mock_writer.write.assert_not_called() | ||||
|         mock__execute_command.assert_not_awaited() | ||||
|  | ||||
|     @patch.object(session.ControlSession, '_parse_command') | ||||
|     async def test_listen(self, mock__parse_command: AsyncMock): | ||||
|         def make_reader_return_empty(): | ||||
|             self.mock_reader.read.return_value = b'' | ||||
|         self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) | ||||
|         msg = "fascinating" | ||||
|         self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) | ||||
|         self.assertIsNone(await self.session.listen()) | ||||
|         self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) | ||||
|         mock__parse_command.assert_awaited_once_with(msg) | ||||
|         self.mock_writer.drain.assert_awaited_once_with() | ||||
|  | ||||
|         self.mock_reader.read.reset_mock() | ||||
|         mock__parse_command.reset_mock() | ||||
|         self.mock_writer.drain.reset_mock() | ||||
|  | ||||
|         self.mock_server.is_serving = MagicMock(return_value=False) | ||||
|         self.assertIsNone(await self.session.listen()) | ||||
|         self.mock_reader.read.assert_not_awaited() | ||||
|         mock__parse_command.assert_not_awaited() | ||||
|         self.mock_writer.drain.assert_not_awaited() | ||||
| @@ -1,134 +0,0 @@ | ||||
| __author__ = "Daniil Fajnberg" | ||||
| __copyright__ = "Copyright © 2022 Daniil Fajnberg" | ||||
| __license__ = """GNU LGPLv3.0 | ||||
|  | ||||
| This file is part of asyncio-taskpool. | ||||
|  | ||||
| asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of | ||||
| version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation. | ||||
|  | ||||
| asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;  | ||||
| without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  | ||||
| See the GNU Lesser General Public License for more details. | ||||
|  | ||||
| You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.  | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Unittests for the `asyncio_taskpool.session_parser` module. | ||||
| """ | ||||
|  | ||||
|  | ||||
| from argparse import Action, ArgumentParser, HelpFormatter, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter | ||||
| from unittest import IsolatedAsyncioTestCase | ||||
| from unittest.mock import MagicMock, patch | ||||
|  | ||||
| from asyncio_taskpool import session_parser | ||||
| from asyncio_taskpool.constants import SESSION_WRITER, CLIENT_INFO | ||||
| from asyncio_taskpool.exceptions import HelpRequested | ||||
|  | ||||
|  | ||||
| FOO = 'foo' | ||||
|  | ||||
|  | ||||
| class ControlServerTestCase(IsolatedAsyncioTestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.help_formatter_factory_patcher = patch.object(session_parser.CommandParser, 'help_formatter_factory') | ||||
|         self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() | ||||
|         self.mock_help_formatter_factory.return_value = RawTextHelpFormatter | ||||
|         self.session_writer, self.terminal_width = MagicMock(), 420 | ||||
|         self.kwargs = { | ||||
|             SESSION_WRITER: self.session_writer, | ||||
|             CLIENT_INFO.TERMINAL_WIDTH: self.terminal_width, | ||||
|             session_parser.FORMATTER_CLASS: FOO | ||||
|         } | ||||
|         self.parser = session_parser.CommandParser(**self.kwargs) | ||||
|  | ||||
|     def tearDown(self) -> None: | ||||
|         self.help_formatter_factory_patcher.stop() | ||||
|  | ||||
|     def test_help_formatter_factory(self): | ||||
|         self.help_formatter_factory_patcher.stop() | ||||
|  | ||||
|         class MockBaseClass(HelpFormatter): | ||||
|             def __init__(self, *args, **kwargs): | ||||
|                 super().__init__(*args, **kwargs) | ||||
|  | ||||
|         terminal_width = 123456789 | ||||
|         cls = session_parser.CommandParser.help_formatter_factory(terminal_width, MockBaseClass) | ||||
|         self.assertTrue(issubclass(cls, MockBaseClass)) | ||||
|         instance = cls('prog') | ||||
|         self.assertEqual(terminal_width, getattr(instance, '_width')) | ||||
|  | ||||
|         cls = session_parser.CommandParser.help_formatter_factory(terminal_width) | ||||
|         self.assertTrue(issubclass(cls, ArgumentDefaultsHelpFormatter)) | ||||
|         instance = cls('prog') | ||||
|         self.assertEqual(terminal_width, getattr(instance, '_width')) | ||||
|  | ||||
|     def test_init(self): | ||||
|         self.assertIsInstance(self.parser, ArgumentParser) | ||||
|         self.assertEqual(self.session_writer, self.parser._session_writer) | ||||
|         self.assertEqual(self.terminal_width, self.parser._terminal_width) | ||||
|         self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) | ||||
|         self.assertFalse(getattr(self.parser, 'exit_on_error')) | ||||
|         self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) | ||||
|  | ||||
|     def test_session_writer(self): | ||||
|         self.assertEqual(self.session_writer, self.parser.session_writer) | ||||
|  | ||||
|     def test_terminal_width(self): | ||||
|         self.assertEqual(self.terminal_width, self.parser.terminal_width) | ||||
|  | ||||
|     def test__print_message(self): | ||||
|         self.session_writer.write = MagicMock() | ||||
|         self.assertIsNone(self.parser._print_message('')) | ||||
|         self.session_writer.write.assert_not_called() | ||||
|         msg = 'foo bar baz' | ||||
|         self.assertIsNone(self.parser._print_message(msg)) | ||||
|         self.session_writer.write.assert_called_once_with(msg.encode()) | ||||
|  | ||||
|     @patch.object(session_parser.CommandParser, '_print_message') | ||||
|     def test_exit(self, mock__print_message: MagicMock): | ||||
|         self.assertIsNone(self.parser.exit(123, '')) | ||||
|         mock__print_message.assert_not_called() | ||||
|         msg = 'foo bar baz' | ||||
|         self.assertIsNone(self.parser.exit(123, msg)) | ||||
|         mock__print_message.assert_called_once_with(msg) | ||||
|  | ||||
|     @patch.object(session_parser.ArgumentParser, 'print_help') | ||||
|     def test_print_help(self, mock_print_help: MagicMock): | ||||
|         arg = MagicMock() | ||||
|         with self.assertRaises(HelpRequested): | ||||
|             self.parser.print_help(arg) | ||||
|         mock_print_help.assert_called_once_with(arg) | ||||
|  | ||||
|     def test_add_optional_num_argument(self): | ||||
|         metavar = 'FOOBAR' | ||||
|         action = self.parser.add_optional_num_argument(metavar=metavar) | ||||
|         self.assertIsInstance(action, Action) | ||||
|         self.assertEqual('?', action.nargs) | ||||
|         self.assertEqual(1, action.default) | ||||
|         self.assertEqual(int, action.type) | ||||
|         self.assertEqual(metavar, action.metavar) | ||||
|         num = 111 | ||||
|         kwargs = vars(self.parser.parse_args([f'{num}'])) | ||||
|         self.assertDictEqual({session_parser.NUM: num}, kwargs) | ||||
|  | ||||
|         name = f'--{FOO}' | ||||
|         nargs = '+' | ||||
|         default = 1 | ||||
|         _type = float | ||||
|         required = True | ||||
|         dest = 'foo_bar' | ||||
|         action = self.parser.add_optional_num_argument(name, nargs=nargs, default=default, type=_type, | ||||
|                                                        required=required, metavar=metavar, dest=dest) | ||||
|         self.assertIsInstance(action, Action) | ||||
|         self.assertEqual(nargs, action.nargs) | ||||
|         self.assertEqual(default, action.default) | ||||
|         self.assertEqual(_type, action.type) | ||||
|         self.assertEqual(required, action.required) | ||||
|         self.assertEqual(metavar, action.metavar) | ||||
|         self.assertEqual(dest, action.dest) | ||||
|         kwargs = vars(self.parser.parse_args([f'{num}', name, '1', '1.5'])) | ||||
|         self.assertDictEqual({session_parser.NUM: num, dest: [1.0, 1.5]}, kwargs) | ||||
							
								
								
									
										275
									
								
								usage/USAGE.md
									
									
									
									
									
								
							
							
						
						
									
										275
									
								
								usage/USAGE.md
									
									
									
									
									
								
							| @@ -1,12 +1,18 @@ | ||||
| # Using `asyncio-taskpool` | ||||
|  | ||||
| ## Contents | ||||
| - [Contents](#contents) | ||||
| - [Minimal example for `SimpleTaskPool`](#minimal-example-for-simpletaskpool) | ||||
| - [Advanced example for `TaskPool`](#advanced-example-for-taskpool) | ||||
| - [Control server example](#control-server-example) | ||||
|  | ||||
| ## Minimal example for `SimpleTaskPool` | ||||
|  | ||||
| With a `SimpleTaskPool` the function to execute as well as the arguments with which to execute it must be defined during its initialization (and they cannot be changed later). The only control you have after initialization is how many of such tasks are being run. | ||||
|  | ||||
| The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.  | ||||
|  | ||||
| The following demo code enables full log output first for additional clarity. It is complete and should work as is. | ||||
|  | ||||
| ### Code | ||||
| The following demo script enables full log output first for additional clarity. It is complete and should work as is. | ||||
|  | ||||
| ```python | ||||
| import logging | ||||
| @@ -28,54 +34,56 @@ async def work(n: int) -> None: | ||||
|     """ | ||||
|     for i in range(n): | ||||
|         await asyncio.sleep(1) | ||||
|         print("did", i) | ||||
|         print("> did", i) | ||||
|  | ||||
|  | ||||
| async def main() -> None: | ||||
|     pool = SimpleTaskPool(work, (5,))  # initializes the pool; no work is being done yet | ||||
|     await pool.start(3)  # launches work tasks 0, 1, and 2 | ||||
|     pool = SimpleTaskPool(work, args=(5,))  # initializes the pool; no work is being done yet | ||||
|     pool.start(3)  # launches work tasks 0, 1, and 2 | ||||
|     await asyncio.sleep(1.5)  # lets the tasks work for a bit | ||||
|     await pool.start()  # launches work task 3 | ||||
|     pool.start(1)  # launches work task 3 | ||||
|     await asyncio.sleep(1.5)  # lets the tasks work for a bit | ||||
|     pool.stop(2)  # cancels tasks 3 and 2 | ||||
|     pool.lock()  # required for the last line | ||||
|     await pool.gather()  # awaits all tasks, then flushes the pool | ||||
|     pool.stop(2)  # cancels tasks 3 and 2 (LIFO order) | ||||
|     await pool.gather_and_close()  # awaits all tasks, then flushes the pool | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     asyncio.run(main()) | ||||
| ``` | ||||
|  | ||||
| ### Output  | ||||
| <details> | ||||
| <summary>Output: (Click to expand)</summary> | ||||
|  | ||||
| ``` | ||||
| SimpleTaskPool-0 initialized | ||||
| Started SimpleTaskPool-0_Task-0 | ||||
| Started SimpleTaskPool-0_Task-1 | ||||
| Started SimpleTaskPool-0_Task-2 | ||||
| did 0 | ||||
| did 0 | ||||
| did 0 | ||||
| > did 0 | ||||
| > did 0 | ||||
| > did 0 | ||||
| Started SimpleTaskPool-0_Task-3 | ||||
| did 1 | ||||
| did 1 | ||||
| did 1 | ||||
| did 0 | ||||
| > did 1 | ||||
| > did 1 | ||||
| > did 1 | ||||
| > did 0 | ||||
| > did 2 | ||||
| > did 2 | ||||
| SimpleTaskPool-0 is locked! | ||||
| Cancelling SimpleTaskPool-0_Task-3 ... | ||||
| Cancelled SimpleTaskPool-0_Task-3 | ||||
| Ended SimpleTaskPool-0_Task-3 | ||||
| Cancelling SimpleTaskPool-0_Task-2 ... | ||||
| Cancelled SimpleTaskPool-0_Task-2 | ||||
| Ended SimpleTaskPool-0_Task-2 | ||||
| did 2 | ||||
| did 2 | ||||
| did 3 | ||||
| did 3 | ||||
| Cancelling SimpleTaskPool-0_Task-3 ... | ||||
| Cancelled SimpleTaskPool-0_Task-3 | ||||
| Ended SimpleTaskPool-0_Task-3 | ||||
| > did 3 | ||||
| > did 3 | ||||
| Ended SimpleTaskPool-0_Task-0 | ||||
| Ended SimpleTaskPool-0_Task-1 | ||||
| did 4 | ||||
| did 4 | ||||
| > did 4 | ||||
| > did 4 | ||||
| ``` | ||||
| </details> | ||||
|  | ||||
| ## Advanced example for `TaskPool` | ||||
|  | ||||
| @@ -83,9 +91,7 @@ This time, we want to start tasks from _different_ coroutine functions **and** w | ||||
|  | ||||
| As with the simple example, we need "worker" coroutine functions that can do something asynchronously, as well as a main coroutine function that sets up the pool, starts the tasks, and eventually awaits them. | ||||
|  | ||||
| The following demo code enables full log output first for additional clarity. It is complete and should work as is. | ||||
|  | ||||
| ### Code | ||||
| The following demo script enables full log output first for additional clarity. It is complete and should work as is. | ||||
|  | ||||
| ```python | ||||
| import logging | ||||
| @@ -101,133 +107,160 @@ async def work(start: int, stop: int, step: int = 1) -> None: | ||||
|     """Pseudo-worker function counting through a range with a second of sleep in between each iteration.""" | ||||
|     for i in range(start, stop, step): | ||||
|         await asyncio.sleep(1) | ||||
|         print("work with", i) | ||||
|         print("> work with", i) | ||||
|  | ||||
|  | ||||
| async def other_work(a: int, b: int) -> None: | ||||
|     """Different pseudo-worker counting through a range with half a second of sleep in between each iteration.""" | ||||
|     for i in range(a, b): | ||||
|         await asyncio.sleep(0.5) | ||||
|         print("other_work with", i) | ||||
|         print("> other_work with", i) | ||||
|  | ||||
|  | ||||
| async def main() -> None: | ||||
|     # Initialize a new task pool instance and limit its size to 3 tasks. | ||||
|     pool = TaskPool(3) | ||||
|     # Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments). | ||||
|     print("Called `apply`") | ||||
|     await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) | ||||
|     # Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments). | ||||
|     print("> Called `apply`") | ||||
|     pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) | ||||
|     # Let the tasks work for a bit. | ||||
|     await asyncio.sleep(1.5) | ||||
|     # Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different  | ||||
|     # positional arguments by using `starmap`, but have **no more than two of those** run concurrently. | ||||
|     # positional arguments by using `starmap`, but we want no more than two of those to run concurrently. | ||||
|     # Since we set our pool size to 3, and already have two tasks working within the pool, | ||||
|     # only the first one of these will start immediately (and receive ID 2). | ||||
|     # The second one will start (with ID 3), only once there is room in the pool, | ||||
|     # which -- in this example -- will be the case after ID 2 ends; | ||||
|     # until then the `starmap` method call **will block**! | ||||
|     # which -- in this example -- will be the case after ID 2 ends. | ||||
|     # Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5) | ||||
|     # **only** once there is room in the pool **and** no more than one of these last four tasks is running. | ||||
|     # only once there is room in the pool and no more than one other task of these new ones is running. | ||||
|     args_list = [(0, 10), (10, 20), (20, 30), (30, 40)] | ||||
|     print("Calling `starmap`...") | ||||
|     await pool.starmap(other_work, args_list, group_size=2) | ||||
|     print("`starmap` returned") | ||||
|     # Now we lock the pool, so that we can safely await all our tasks. | ||||
|     pool.lock() | ||||
|     # Finally, we block, until all tasks have ended. | ||||
|     print("Called `gather`") | ||||
|     await pool.gather() | ||||
|     print("Done.") | ||||
|     pool.starmap(other_work, args_list, num_concurrent=2) | ||||
|     print("> Called `starmap`") | ||||
|     # We block, until all tasks have ended. | ||||
|     print("> Calling `gather_and_close`...") | ||||
|     await pool.gather_and_close() | ||||
|     print("> Done.") | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     asyncio.run(main()) | ||||
| ``` | ||||
|  | ||||
| ### Output  | ||||
| Additional comments for the output are provided with `<---` next to the output lines. | ||||
| <details> | ||||
| <summary>Output: (Click to expand)</summary> | ||||
|  | ||||
| (Keep in mind that the logger and `print` asynchronously write to `stdout`.) | ||||
| ``` | ||||
| TaskPool-0 initialized | ||||
| Started TaskPool-0_Task-0 | ||||
| Started TaskPool-0_Task-1 | ||||
| Called `apply` | ||||
| work with 100 | ||||
| work with 100 | ||||
| Calling `starmap`...    <--- notice that this blocks as expected | ||||
| Started TaskPool-0_Task-2 | ||||
| work with 110 | ||||
| work with 110 | ||||
| other_work with 0 | ||||
| other_work with 1 | ||||
| work with 120 | ||||
| work with 120 | ||||
| other_work with 2 | ||||
| other_work with 3 | ||||
| work with 130 | ||||
| work with 130 | ||||
| other_work with 4 | ||||
| other_work with 5 | ||||
| work with 140 | ||||
| work with 140 | ||||
| other_work with 6 | ||||
| other_work with 7 | ||||
| work with 150 | ||||
| work with 150 | ||||
| other_work with 8 | ||||
| Ended TaskPool-0_Task-2    <--- here Task-2 makes room in the pool and unblocks `main()` | ||||
| > Called `apply` | ||||
| > work with 100 | ||||
| > work with 100 | ||||
| > Called `starmap`   <--- notice that this immediately returns, even before Task-2 is started | ||||
| > Calling `gather_and_close`...    <--- this blocks `main()` until all tasks have ended | ||||
| TaskPool-0 is locked! | ||||
| Started TaskPool-0_Task-2    <--- at this point the pool is full | ||||
| > work with 110 | ||||
| > work with 110 | ||||
| > other_work with 0 | ||||
| > other_work with 1 | ||||
| > work with 120 | ||||
| > work with 120 | ||||
| > other_work with 2 | ||||
| > other_work with 3 | ||||
| > work with 130 | ||||
| > work with 130 | ||||
| > other_work with 4 | ||||
| > other_work with 5 | ||||
| > work with 140 | ||||
| > work with 140 | ||||
| > other_work with 6 | ||||
| > other_work with 7 | ||||
| > work with 150 | ||||
| > work with 150 | ||||
| > other_work with 8 | ||||
| Ended TaskPool-0_Task-2    <--- this frees up room for one more task from `starmap` | ||||
| Started TaskPool-0_Task-3 | ||||
| other_work with 9 | ||||
| `starmap` returned | ||||
| Called `gather`    <--- now this will block `main()` until all tasks have ended | ||||
| work with 160 | ||||
| work with 160 | ||||
| other_work with 10 | ||||
| other_work with 11 | ||||
| work with 170 | ||||
| work with 170 | ||||
| other_work with 12 | ||||
| other_work with 13 | ||||
| work with 180 | ||||
| work with 180 | ||||
| other_work with 14 | ||||
| other_work with 15 | ||||
| > other_work with 9 | ||||
| > work with 160 | ||||
| > work with 160 | ||||
| > other_work with 10 | ||||
| > other_work with 11 | ||||
| > work with 170 | ||||
| > work with 170 | ||||
| > other_work with 12 | ||||
| > other_work with 13 | ||||
| > work with 180 | ||||
| > work with 180 | ||||
| > other_work with 14 | ||||
| > other_work with 15 | ||||
| Ended TaskPool-0_Task-0 | ||||
| Ended TaskPool-0_Task-1    <--- even though there is room in the pool now, Task-5 will not start | ||||
| Started TaskPool-0_Task-4 | ||||
| work with 190 | ||||
| work with 190 | ||||
| other_work with 16 | ||||
| other_work with 20 | ||||
| other_work with 17 | ||||
| other_work with 21 | ||||
| other_work with 18 | ||||
| other_work with 22 | ||||
| other_work with 19 | ||||
| Ended TaskPool-0_Task-3    <--- now that only Task-4 is left, Task-5 will start | ||||
| Ended TaskPool-0_Task-1    <--- these two end and free up two more slots in the pool | ||||
| Started TaskPool-0_Task-4    <--- since `num_concurrent` is set to 2, Task-5 will not start | ||||
| > work with 190 | ||||
| > work with 190 | ||||
| > other_work with 16 | ||||
| > other_work with 17 | ||||
| > other_work with 20 | ||||
| > other_work with 18 | ||||
| > other_work with 21 | ||||
| Ended TaskPool-0_Task-3    <--- now that only Task-4 of the group remains, Task-5 starts | ||||
| Started TaskPool-0_Task-5 | ||||
| other_work with 23 | ||||
| other_work with 30 | ||||
| other_work with 24 | ||||
| other_work with 31 | ||||
| other_work with 25 | ||||
| other_work with 32 | ||||
| other_work with 26 | ||||
| other_work with 33 | ||||
| other_work with 27 | ||||
| other_work with 34 | ||||
| other_work with 28 | ||||
| other_work with 35 | ||||
| > other_work with 19 | ||||
| > other_work with 22 | ||||
| > other_work with 23 | ||||
| > other_work with 30 | ||||
| > other_work with 24 | ||||
| > other_work with 31 | ||||
| > other_work with 25 | ||||
| > other_work with 32 | ||||
| > other_work with 26 | ||||
| > other_work with 33 | ||||
| > other_work with 27 | ||||
| > other_work with 34 | ||||
| > other_work with 28 | ||||
| > other_work with 35 | ||||
| > other_work with 29 | ||||
| > other_work with 36 | ||||
| Ended TaskPool-0_Task-4 | ||||
| other_work with 29 | ||||
| other_work with 36 | ||||
| other_work with 37 | ||||
| other_work with 38 | ||||
| other_work with 39 | ||||
| Done. | ||||
| > other_work with 37 | ||||
| > other_work with 38 | ||||
| > other_work with 39 | ||||
| Ended TaskPool-0_Task-5 | ||||
| > Done. | ||||
| ``` | ||||
|  | ||||
| (Added comments with `<---` next to the output lines.) | ||||
|  | ||||
| Keep in mind that the logger and `print` asynchronously write to `stdout`, so the order of lines in your output may be slightly different. | ||||
| </details> | ||||
|  | ||||
| ## Control server example | ||||
|  | ||||
| One of the main features of `asyncio-taskpool` is the ability to control a task pool "from the outside" at runtime. | ||||
|  | ||||
| The [example_server.py](./example_server.py) script launches a couple of worker tasks within a `SimpleTaskPool` instance and then starts a `TCPControlServer` instance for that task pool. The server is configured to locally bind to port `9999` and is stopped automatically after the "work" is done. | ||||
|  | ||||
| To run the script: | ||||
| ```shell | ||||
| python usage/example_server.py | ||||
| ``` | ||||
|  | ||||
| You can then connect to the server via the command line interface: | ||||
| ```shell | ||||
| python -m asyncio_taskpool.control tcp localhost 9999 | ||||
| ``` | ||||
|  | ||||
| The CLI starts a `TCPControlClient` that connects to our example server. Once the connection is established, it gives you an input prompt allowing you to issue commands to the task pool: | ||||
| ``` | ||||
| Connected to SimpleTaskPool-0 | ||||
| Type '-h' to get help and usage instructions for all available commands. | ||||
|  | ||||
| > | ||||
| ``` | ||||
|  | ||||
| It may be useful to run the server script and the client interface in two separate terminal windows side by side. The server script is configured with a verbose logger and will react to any commands issued by the client with detailed log messages in the terminal. | ||||
|  | ||||
| --- | ||||
|  | ||||
| © 2022 Daniil Fajnberg | ||||
|   | ||||
| @@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w | ||||
| If not, see <https://www.gnu.org/licenses/>.""" | ||||
|  | ||||
| __doc__ = """ | ||||
| Working example of a UnixControlServer in combination with the SimpleTaskPool. | ||||
| Working example of a TCPControlServer in combination with the SimpleTaskPool. | ||||
| Use the main CLI client to interface at the socket. | ||||
| """ | ||||
|  | ||||
| @@ -23,8 +23,9 @@ Use the main CLI client to interface at the socket. | ||||
| import asyncio | ||||
| import logging | ||||
|  | ||||
| from asyncio_taskpool import SimpleTaskPool, UnixControlServer | ||||
| from asyncio_taskpool.constants import PACKAGE_NAME | ||||
| from asyncio_taskpool import SimpleTaskPool | ||||
| from asyncio_taskpool.control import TCPControlServer | ||||
| from asyncio_taskpool.internals.constants import PACKAGE_NAME | ||||
|  | ||||
|  | ||||
| logging.getLogger().setLevel(logging.NOTSET) | ||||
| @@ -34,11 +35,11 @@ logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler()) | ||||
| async def work(item: int) -> None: | ||||
|     """The non-blocking sleep simulates something like an I/O operation that can be done asynchronously.""" | ||||
|     await asyncio.sleep(1) | ||||
|     print("worked on", item) | ||||
|     print("worked on", item, flush=True) | ||||
|  | ||||
|  | ||||
| async def worker(q: asyncio.Queue) -> None: | ||||
|     """Simulates doing asynchronous work that takes a little bit of time to finish.""" | ||||
|     """Simulates doing asynchronous work that takes a bit of time to finish.""" | ||||
|     # We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop. | ||||
|     while True: | ||||
|         # We want to block here, until we can get the next item from the queue. | ||||
| @@ -65,21 +66,20 @@ async def main() -> None: | ||||
|     # We just put some integers into our queue, since all our workers actually do, is print an item and sleep for a bit. | ||||
|     for item in range(100): | ||||
|         q.put_nowait(item) | ||||
|     pool = SimpleTaskPool(worker, (q,))  # initializes the pool | ||||
|     await pool.start(3)  # launches three worker tasks | ||||
|     control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever() | ||||
|     pool = SimpleTaskPool(worker, args=(q,))  # initializes the pool | ||||
|     pool.start(3)  # launches three worker tasks | ||||
|     control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever() | ||||
|     # We block until `.task_done()` has been called once by our workers for every item placed into the queue. | ||||
|     await q.join() | ||||
|     # Since we don't need any "work" done anymore, we can lock our control server by cancelling the task. | ||||
|     # Since we don't need any "work" done anymore, we can get rid of our control server by cancelling the task. | ||||
|     control_server_task.cancel() | ||||
|     # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, | ||||
|     # we can now safely cancel their tasks. | ||||
|     pool.stop_all() | ||||
|     pool.lock() | ||||
|     # Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled. | ||||
|     # Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled. | ||||
|     # We block until they all return or raise an exception, but since we are not interested in any of their exceptions, | ||||
|     # we just silently collect their exceptions along with their return values. | ||||
|     await pool.gather(return_exceptions=True) | ||||
|     await pool.gather_and_close(return_exceptions=True) | ||||
|     await control_server_task | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user