Compare commits

..

33 Commits

Author SHA1 Message Date
73aa93a9b7
Fix unit tests 2022-05-08 15:26:17 +02:00
051d0cb911
Fix control server response writes 2022-05-08 15:20:07 +02:00
ee0b8c0002
End control server response with newline 2022-05-07 16:16:50 +02:00
28c997e0ee
Fix client and session tests 2022-05-07 14:54:33 +02:00
5a72a6d1d1
Fix server's control session reading from socket; improve docstrings 2022-05-07 14:44:43 +02:00
72e380cd77
Version increment 2022-05-06 18:57:45 +02:00
85672bddeb
Add short docstring to until_closed method 2022-05-05 08:46:27 +02:00
dae883446a
Add until_closed method to pools 2022-05-05 08:35:32 +02:00
a4ecf39157
Fix socket path bug in control client 2022-04-27 17:02:34 +02:00
e3bbb05eac
Fix version number 2022-04-18 13:49:40 +02:00
36527ccffc
Improve exception handling in _start_num 2022-04-18 13:45:15 +02:00
d047b99119
Fix cancel message bug for Python 3.8; test coverage workaround for Python version conditions 2022-04-10 10:43:53 +02:00
d7cd16c540
Add docs on IDs and groups 2022-04-09 14:53:07 +02:00
db306a1a1f
Fixed Python 3.8 compatibility bugs; classmethod+property workaround; control session buffer 2022-04-08 11:53:53 +02:00
3a8fcb2d5a
Optimize coverage script/settings 2022-04-06 21:47:39 +02:00
cf02206588
Add RtD config 2022-04-04 10:06:39 +02:00
0796038dcd
Update Readme 2022-04-03 16:52:51 +02:00
f4e33baf82
Add workflows & badges 2022-04-03 16:33:28 +02:00
9b838b6130
improved exception handling/logging in _map 2022-03-31 21:05:01 +02:00
0daed04167
improved exception handling/logging in apply meta task 2022-03-31 11:04:31 +02:00
80fc91ec47
made start "non-async" using meta task 2022-03-30 21:51:19 +02:00
a72a7cc516
made a bunch of methods "non-async" 2022-03-30 18:16:28 +02:00
91d546ebc2
moved the entire meta task logic to the base class 2022-03-30 16:17:34 +02:00
5b3ac52bf6
start method now returns group name 2022-03-30 15:31:38 +02:00
82e6ca7b1a
task group naming logic changed 2022-03-30 12:37:32 +02:00
153127e028
lock before gathering meta tasks 2022-03-30 11:47:15 +02:00
17539e9c27
made apply non-blocking by using a meta-task 2022-03-29 20:01:44 +02:00
1beb9fc9b0
gather_and_close now automatically locks the pool 2022-03-29 19:43:21 +02:00
23a4cb028a
renamed group_size to num_concurrent in _map 2022-03-29 19:34:27 +02:00
54e5bfa8a0
drastically simplified meta-task internals 2022-03-29 19:34:13 +02:00
0e7e92a91b
better error handling when converting arguments 2022-03-26 10:29:34 +01:00
a9011076c4
fixed potential race cond. gathering meta tasks 2022-03-25 12:58:18 +01:00
7e34aa106d
sphinx documentation; adjusted all docstrings; moved some modules to non-public subpackage 2022-03-24 13:38:30 +01:00
53 changed files with 2132 additions and 985 deletions

View File

@ -1,14 +1,12 @@
[run] [run]
source = src/ source = src/
branch = true branch = true
omit = command_line = -m unittest discover
.venv/*
[report] [report]
fail_under = 100
show_missing = True show_missing = True
skip_covered = False skip_covered = False
exclude_lines = exclude_lines =
if TYPE_CHECKING: if TYPE_CHECKING:
if __name__ == ['"]__main__['"]: if __name__ == ['"]__main__['"]:
omit =
tests/*

88
.github/workflows/main.yaml vendored Normal file
View File

@ -0,0 +1,88 @@
name: CI
on:
push:
branches: [master]
jobs:
tests:
name: Python ${{ matrix.python-version }} Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- '3.8'
- '3.9'
- '3.10'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: 'requirements/dev.txt'
- name: Upgrade packaging tools
run: pip install -U pip
- name: Install dependencies
run: pip install -U -r requirements/dev.txt
- name: Install asyncio-taskpool
run: pip install -e .
- name: Run tests for Python ${{ matrix.python-version }}
if: ${{ matrix.python-version != '3.10' }}
run: python -m tests
- name: Run tests for Python 3.10 and save coverage
if: ${{ matrix.python-version == '3.10' }}
run: echo "coverage=$(./coverage.sh)" >> $GITHUB_ENV
outputs:
coverage: ${{ env.coverage }}
update_badges:
needs: tests
name: Update Badges
env:
meta_gist_id: 3f8240a976e8781a765d9c74a583dcda
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Download `cloc`
run: sudo apt-get update -y && sudo apt-get install -y cloc
- name: Count lines of code/comments
run: |
echo "cloc_code=$(./cloc.sh -c src/)" >> $GITHUB_ENV
echo "cloc_comments=$(./cloc.sh -m src/)" >> $GITHUB_ENV
echo "cloc_commentpercent=$(./cloc.sh -p src/)" >> $GITHUB_ENV
- name: Create badge for lines of code
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: cloc-code.json
label: Lines of Code
message: ${{ env.cloc_code }}
- name: Create badge for lines of comments
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: cloc-comments.json
label: Comments
message: ${{ env.cloc_comments }} (${{ env.cloc_commentpercent }}%)
- name: Create badge for test coverage
uses: Schneegans/dynamic-badges-action@v1.2.0
with:
auth: ${{ secrets.GIST_META_DATA }}
gistID: ${{ env.meta_gist_id }}
filename: test-coverage.json
label: Coverage
message: ${{ needs.tests.outputs.coverage }}

3
.gitignore vendored
View File

@ -3,9 +3,10 @@
# IDE settings: # IDE settings:
/.idea/ /.idea/
/.vscode/ /.vscode/
# Distribution / packaging: # Distribution / build files:
*.egg-info/ *.egg-info/
/dist/ /dist/
/docs/build/
# Python cache: # Python cache:
__pycache__/ __pycache__/
# Testing: # Testing:

11
.readthedocs.yaml Normal file
View File

@ -0,0 +1,11 @@
version: 2
build:
os: 'ubuntu-20.04'
tools:
python: '3.8'
python:
install:
- method: pip
path: .
sphinx:
fail_on_warning: true

View File

@ -1,7 +1,30 @@
[//]: # (This file is part of asyncio-taskpool.)
[//]: # (asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of)
[//]: # (version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.)
[//]: # (asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;)
[//]: # (without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.)
[//]: # (See the GNU Lesser General Public License for more details.)
[//]: # (You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.)
[//]: # (If not, see <https://www.gnu.org/licenses/>.)
# asyncio-taskpool # asyncio-taskpool
[![GitHub last commit][github-last-commit-img]][github-last-commit]
![Lines of code][gist-cloc-code-img]
![Lines of comments][gist-cloc-comments-img]
![Test coverage][gist-test-coverage-img]
[![License: LGPL v3.0][lgpl3-img]][lgpl3]
[![PyPI version][pypi-latest-version-img]][pypi-latest-version]
**Dynamically manage pools of asyncio tasks** **Dynamically manage pools of asyncio tasks**
Full documentation available at [RtD](https://asyncio-taskpool.readthedocs.io/en/latest).
---
## Contents ## Contents
- [Contents](#contents) - [Contents](#contents)
- [Summary](#summary) - [Summary](#summary)
@ -27,25 +50,16 @@ Generally speaking, a task is added to a pool by providing it with a coroutine f
```python ```python
from asyncio_taskpool import SimpleTaskPool from asyncio_taskpool import SimpleTaskPool
... ...
async def work(_foo, _bar): ... async def work(_foo, _bar): ...
...
async def main(): async def main():
pool = SimpleTaskPool(work, args=('xyz', 420)) pool = SimpleTaskPool(work, args=('xyz', 420))
await pool.start(5) pool.start(5)
... ...
pool.stop(3) pool.stop(3)
... ...
pool.lock()
await pool.gather_and_close() await pool.gather_and_close()
...
``` ```
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.) Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
@ -64,8 +78,7 @@ Python Version 3.8+, tested on Linux
## Testing ## Testing
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`. Install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`, then execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and save the coverage report.
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
## License ## License
@ -76,3 +89,13 @@ The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COP
--- ---
© 2022 Daniil Fajnberg © 2022 Daniil Fajnberg
[github-last-commit]: https://github.com/daniil-berg/asyncio-taskpool/commits
[github-last-commit-img]: https://img.shields.io/github/last-commit/daniil-berg/asyncio-taskpool?label=Last%20commit&logo=git&
[gist-cloc-code-img]: https://img.shields.io/endpoint?logo=python&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-code.json
[gist-cloc-comments-img]: https://img.shields.io/endpoint?logo=sharp&color=lightgrey&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-comments.json
[gist-test-coverage-img]: https://img.shields.io/endpoint?logo=pytest&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/test-coverage.json
[lgpl3]: https://www.gnu.org/licenses/lgpl-3.0
[lgpl3-img]: https://img.shields.io/badge/License-LGPL_v3.0-darkgreen.svg?logo=gnu
[pypi-latest-version-img]: https://img.shields.io/pypi/v/asyncio-taskpool?color=teal&logo=pypi
[pypi-latest-version]: https://pypi.org/project/asyncio-taskpool/

46
cloc.sh Executable file
View File

@ -0,0 +1,46 @@
#!/usr/bin/env bash
# This file is part of asyncio-taskpool.
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>.
typeset option
if getopts 'bcmp' option; then
if [[ ${option} == [bcmp] ]]; then
shift
else
echo >&2 "Invalid option '$1' provided"
exit 1
fi
fi
typeset source=$1
if [[ -z ${source} ]]; then
echo >&2 Source file/directory missing
exit 1
fi
typeset blank code comment commentpercent
read blank comment code commentpercent < <( \
cloc --csv --quiet --hide-rate --include-lang Python ${source} |
awk -F, '$2 == "SUM" {printf ("%d %d %d %1.0f", $3, $4, $5, 100 * $4 / ($5 + $4)); exit}'
)
case ${option} in
b) echo ${blank} ;;
c) echo ${code} ;;
m) echo ${comment} ;;
p) echo ${commentpercent} ;;
*) echo Blank lines: ${blank}
echo Lines of comments: ${comment}
echo Lines of code: ${code}
echo Comment percentage: ${commentpercent} ;;
esac

View File

@ -1,3 +1,25 @@
#!/usr/bin/env sh #!/usr/bin/env bash
coverage erase && coverage run -m unittest discover && coverage report # This file is part of asyncio-taskpool.
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
# If not, see <https://www.gnu.org/licenses/>.
coverage erase
coverage run 2> /dev/null
typeset report=$(coverage report)
typeset total=$(echo "${report}" | awk '$1 == "TOTAL" {print $NF; exit}')
if [[ ${total} == 100% ]]; then
echo ${total}
else
echo "${report}"
fi

20
docs/Makefile Normal file
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
docs/make.bat Normal file
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

View File

7
docs/source/api/api.rst Normal file
View File

@ -0,0 +1,7 @@
API
===
.. toctree::
:maxdepth: 4
asyncio_taskpool

View File

@ -0,0 +1,7 @@
asyncio\_taskpool.control.client module
=======================================
.. automodule:: asyncio_taskpool.control.client
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,16 @@
asyncio\_taskpool.control package
=================================
.. automodule:: asyncio_taskpool.control
:members:
:undoc-members:
:show-inheritance:
Submodules
----------
.. toctree::
:maxdepth: 4
asyncio_taskpool.control.client
asyncio_taskpool.control.server

View File

@ -0,0 +1,7 @@
asyncio\_taskpool.control.server module
=======================================
.. automodule:: asyncio_taskpool.control.server
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,7 @@
asyncio\_taskpool.exceptions module
===================================
.. automodule:: asyncio_taskpool.exceptions
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,7 @@
asyncio\_taskpool.pool module
=============================
.. automodule:: asyncio_taskpool.pool
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,7 @@
asyncio\_taskpool.queue\_context module
=======================================
.. automodule:: asyncio_taskpool.queue_context
:members:
:undoc-members:
:show-inheritance:

View File

@ -0,0 +1,25 @@
asyncio\_taskpool package
=========================
.. automodule:: asyncio_taskpool
:members:
:undoc-members:
:show-inheritance:
Subpackages
-----------
.. toctree::
:maxdepth: 4
asyncio_taskpool.control
Submodules
----------
.. toctree::
:maxdepth: 4
asyncio_taskpool.exceptions
asyncio_taskpool.pool
asyncio_taskpool.queue_context

60
docs/source/conf.py Normal file
View File

@ -0,0 +1,60 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = 'asyncio-taskpool'
copyright = '2022 Daniil Fajnberg'
author = 'Daniil Fajnberg'
# The full version, including alpha/beta/rc tags
release = '1.1.4'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.duration',
'sphinx.ext.napoleon'
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme_options = {
'style_external_links': True,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

58
docs/source/index.rst Normal file
View File

@ -0,0 +1,58 @@
.. This file is part of asyncio-taskpool.
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.
.. Copyright © 2022 Daniil Fajnberg
Welcome to the asyncio-taskpool documentation!
==============================================
:code:`asyncio-taskpool` is a Python library for dynamically and conveniently managing pools of `asyncio <https://docs.python.org/3/library/asyncio.html>`_ tasks.
Purpose
-------
A `task <https://docs.python.org/3/library/asyncio-task.html>`_ is a very powerful tool of concurrency in the Python world. Since concurrency always implies doing more than one thing a time, you rarely deal with just one :code:`Task` instance. However, managing multiple tasks can become a bit cumbersome quickly, as their number increases. Moreover, especially in long-running code, you may find it useful (or even necessary) to dynamically adjust the extent to which the work is distributed, i.e. increase or decrease the number of tasks.
With that in mind, this library aims to provide two things:
#. An additional layer of abstraction and convenience for managing multiple tasks.
#. A simple interface for dynamically adding and removing tasks when a program is already running.
The first is achieved through the concept of a :doc:`task pool <pages/pool>`. The second is achieved by adding a :doc:`control server <pages/control>` to the task pool.
Installation
------------
.. code-block:: bash
$ pip install asyncio-taskpool
Contents
--------
.. toctree::
:maxdepth: 2
pages/pool
pages/ids
pages/control
api/api
Indices and tables
------------------
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

View File

@ -0,0 +1,107 @@
.. This file is part of asyncio-taskpool.
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.
.. Copyright © 2022 Daniil Fajnberg
Control interface
=================
When you are dealing with programs that run for a long period of time or even as daemons (i.e. indefinitely), having a way to adjust their behavior without needing to stop and restart them can be desirable.
Task pools offer a high degree of flexibility regarding the number and kind of tasks that run within them, by providing methods to easily start and stop tasks and task groups. But without additional tools, they only allow you to establish a control logic *a priori*, as demonstrated in :ref:`this code snippet <simple-control-logic>`.
What if you have a long-running program that executes certain tasks concurrently, but you don't know in advance how many of them you'll need? What if you want to be able to adjust the number of tasks manually **without stopping the task pool**?
The control server
------------------
The :code:`asyncio-taskpool` library comes with a simple control interface for managing task pools that are already running, at the heart of which is the :py:class:`ControlServer <asyncio_taskpool.control.server.ControlServer>`. Any task pool can be passed to a control server. Once the server is running, you can issue commands to it either via TCP or via UNIX socket. The commands map directly to the task pool methods.
To enable control over a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` via local TCP port :code:`8001`, all you need to do is this:
.. code-block:: python
:caption: main.py
:name: control-server-minimal
from asyncio_taskpool import SimpleTaskPool
from asyncio_taskpool.control import TCPControlServer
from .work import any_worker_func
async def main():
...
pool = SimpleTaskPool(any_worker_func, kwargs={'foo': 42, 'bar': some_object})
control = await TCPControlServer(pool, host='127.0.0.1', port=8001).serve_forever()
await control
Under the hood, the :py:class:`ControlServer <asyncio_taskpool.control.server.ControlServer>` simply uses :code:`asyncio.start_server` for instantiating a socket server. The resulting control task will run indefinitely. Cancelling the control task stops the server.
In reality, you would probably want some graceful handler for an interrupt signal that cancels any remaining tasks as well as the serving control task.
The control client
------------------
Technically, any process that can read from and write to the socket exposed by the control server, will be able to interact with it. The :code:`asyncio-taskpool` package has its own simple implementation in the form of the :py:class:`ControlClient <asyncio_taskpool.control.client.ControlClient>` that makes it easy to use out of the box.
To start a client, you can use the main script of the :py:mod:`asyncio_taskpool.control` sub-package like this:
.. code-block:: bash
$ python -m asyncio_taskpool.control tcp localhost 8001
This would establish a connection to the control server from the previous example. Calling
.. code-block:: bash
$ python -m asyncio_taskpool.control -h
will display the available client options.
The control session
-------------------
Assuming you connected successfully, you should be greeted by the server with a help message and dropped into a simple input prompt.
.. code-block:: none
Connected to SimpleTaskPool-0
Type '-h' to get help and usage instructions for all available commands.
>
The input sent to the server is handled by a typical argument parser, so the interface should be straight-forward. A command like
.. code-block:: none
> start 5
will call the :py:meth:`.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method with :code:`5` as an argument and thus start 5 new tasks in the pool, while the command
.. code-block:: none
> pool-size
will call the :py:meth:`.pool_size <asyncio_taskpool.pool.BaseTaskPool.pool_size>` property getter and return the maximum number of tasks you that can run in the pool.
When you are dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, starting new tasks works just fine, as long as the coroutine functions you want to use can be imported into the namespace of the pool. If you have a function named :code:`worker` in the module :code:`mymodule` under the package :code:`mypackage` and want to use it in a :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` call with the arguments :code:`'x'`, :code:`'x'`, and :code:`'z'`, you would do it like this:
.. code-block:: none
> map mypackage.mymodule.worker ['x','y','z'] -n 3
The :code:`-n` is a shorthand for :code:`--num-concurrent` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
.. note::
The :code:`ast.literal_eval` function from the `standard library <https://docs.python.org/3/library/ast.html#ast.literal_eval>`_ is used to safely evaluate the iterable of arguments to work on. For obvious reasons, being able to provide arbitrary python objects in such a control session is neither practical nor secure. The way this is implemented now is limited in that regard, since you can only use Python literals and containers as arguments for your coroutine functions.
To exit a control session, use the :code:`exit` command or simply press :code:`Ctrl + D`.

42
docs/source/pages/ids.rst Normal file
View File

@ -0,0 +1,42 @@
.. This file is part of asyncio-taskpool.
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.
.. Copyright © 2022 Daniil Fajnberg
IDs, groups & names
===================
Task IDs
--------
Every task spawned within a pool receives an ID, which is an integer greater or equal to 0 that is unique **within that task pool instance**. An internal counter is incremented whenever a new task is spawned. A task with ID :code:`n` was the :code:`(n+1)`-th task to be spawned in the pool. Task IDs can be used to cancel specific tasks using the :py:meth:`.cancel() <asyncio_taskpool.pool.BaseTaskPool.cancel>` method.
In practice, it should rarely be necessary to target *specific* tasks. When dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, you would typically cancel entire task groups (see below) rather than individual tasks, whereas with :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instances you would indiscriminately cancel a number of tasks using the :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` method.
The ID of a pool task also appears in the task's name, which is set upon spawning it. (See `here <https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.set_name>`_ for the associated method of the :code:`Task` class.)
Task groups
-----------
Every method of spawning new tasks in a task pool will add them to a **task group** and return the name of that group. With :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` methods such as :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>`, the group name can be set explicitly via the :code:`group_name` parameter. By default, the name will be a string containing some meta information depending on which method is used. Passing an existing task group name in any of those methods will result in a :py:class:`InvalidGroupName <asyncio_taskpool.exceptions.InvalidGroupName>` error.
You can cancel entire task groups using the :py:meth:`.cancel_group() <asyncio_taskpool.pool.BaseTaskPool.cancel_group>` method by passing it the group name. To check which tasks belong to a group, the :py:meth:`.get_group_ids() <asyncio_taskpool.pool.BaseTaskPool.get_group_ids>` method can be used, which takes group names and returns the IDs of the tasks belonging to them.
The :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method will create a new group as well, each time it is called, but it does not allow customizing the group name. Typically, it will not be necessary to keep track of groups in a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instance.
Task groups do not impose limits on the number of tasks in them, although they can be indirectly constrained by pool size limits.
Pool names
----------
When initializing a task pool, you can provide a custom name for it, which will appear in its string representation, e.g. when using it in a :code:`print()`. A class attribute keeps track of initialized task pools and assigns each one an index (similar to IDs for pool tasks). If no name is specified when creating a new pool, its index is used in the string representation of it. Pool names can be helpful when using multiple pools and analyzing log messages.

233
docs/source/pages/pool.rst Normal file
View File

@ -0,0 +1,233 @@
.. This file is part of asyncio-taskpool.
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.
.. Copyright © 2022 Daniil Fajnberg
Task pools
==========
What is a task pool?
--------------------
A task pool is an object with a simple interface for aggregating and dynamically managing asynchronous tasks.
To make use of task pools, your code obviously needs to contain coroutine functions (introduced with the :code:`async def` keywords). By adding such functions along with their arguments to a task pool, they are turned into tasks and executed asynchronously.
If you are familiar with the :code:`Pool` class of the `multiprocessing module <https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool>`_ from the standard library, then you should feel at home with the :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` class. Obviously, there are major conceptual and functional differences between the two, but the methods provided by the :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` follow a very similar logic. If you never worked with process or thread pools, don't worry. Task pools are much simpler.
The :code:`TaskPool` class
--------------------------
There are essentially two distinct use cases for a concurrency pool. You want to
#. execute a function *n* times with the same arguments concurrently or
#. execute a function *n* times with different arguments concurrently.
The first is accomplished with the :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` method, while the second is accomplished with the :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` method and its variations :py:meth:`.starmap() <asyncio_taskpool.pool.TaskPool.starmap>` and :py:meth:`.doublestarmap() <asyncio_taskpool.pool.TaskPool.doublestarmap>`.
Let's take a look at an example. Say you have a coroutine function that takes two queues as arguments: The first one being an input-queue (containing items to work on) and the second one being the output queue (for passing on the results to some other function). Your function may look something like this:
.. code-block:: python
:caption: work.py
:name: queue-worker-function
from asyncio.queues import Queue
async def queue_worker_function(in_queue: Queue, out_queue: Queue) -> None:
while True:
item = await in_queue.get()
... # Do some work on the item and arrive at a result.
await out_queue.put(result)
How would we go about concurrently executing this function, say 5 times? There are (as always) a number of ways to do this with :code:`asyncio`. If we want to use tasks and be clean about it, we can do it like this:
.. code-block:: python
:caption: main.py
from asyncio.tasks import create_task, gather
from .work import queue_worker_function
...
# We assume that the queues have been initialized already.
tasks = []
for _ in range(5):
new_task = create_task(queue_worker_function(q_in, q_out))
tasks.append(new_task)
# Run some other code and let the tasks do their thing.
...
# At some point, we want the tasks to stop waiting for new items and end.
for task in tasks:
task.cancel()
...
await gather(*tasks)
By contrast, here is how you would do it with a task pool:
.. code-block:: python
:caption: main.py
from asyncio_taskpool import TaskPool
from .work import queue_worker_function
...
pool = TaskPool()
group_name = pool.apply(queue_worker_function, args=(q_in, q_out), num=5)
...
pool.cancel_group(group_name)
...
await pool.flush()
Pretty much self-explanatory, no? (See :doc:`here <./ids>` for more information about groups/names).
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
.. code-block:: python
:caption: work.py
:name: another-worker-function
from .my_database_stuff import insert_into_results_table
async def another_worker_function(data: object) -> None:
if data.some_attribute > 1:
...
# Do the work, arrive at results.
await insert_into_results_table(results)
Say we have some *iterator* of data-items (of arbitrary length) that we want to be worked on, and say we want 5 coroutines concurrently working on that data. Here is a very naive task-based solution:
.. code-block:: python
:caption: main.py
from asyncio.tasks import create_task, gather
from .work import another_worker_function
async def main():
...
# We got our data_iterator from somewhere.
keep_going = True
while keep_going:
tasks = []
for _ in range(5):
try:
data = next(data_iterator)
except StopIteration:
keep_going = False
break
new_task = create_task(another_worker_function(data))
tasks.append(new_task)
await gather(*tasks)
Here we already run into problems with the task-based approach. The last line in our :code:`while`-loop blocks until **all 5 tasks** return (or raise an exception). This means that as soon as one of them returns, the number of working coroutines is already less than 5 (until all the others return). This can obviously be solved in different ways. We could, for instance, wrap the creation of new tasks itself in a coroutine, which immediately creates a new task, when one is finished, and then call that coroutine 5 times concurrently. Or we could use the queue-based approach from before, but then we would need to write some queue producing coroutine.
Or we could use a task pool:
.. code-block:: python
:caption: main.py
from asyncio_taskpool import TaskPool
from .work import another_worker_function
async def main():
...
pool = TaskPool()
pool.map(another_worker_function, data_iterator, num_concurrent=5)
...
await pool.gather_and_close()
Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction).
The :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`)
.. note::
Neither :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` nor :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` return coroutines. When they are called, the task pool immediately begins scheduling new tasks to run. No :code:`await` needed.
It can't get any simpler than that, can it? So glad you asked...
The :code:`SimpleTaskPool` class
--------------------------------
Let's take the :ref:`queue worker example <queue-worker-function>` from before. If we know that the task pool will only ever work with that one function with the same queue objects, we can make use of the :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` class:
.. code-block:: python
:caption: main.py
from asyncio_taskpool import SimpleTaskPool
from .work import queue_worker_function
async def main():
...
pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out))
pool.start(5)
...
pool.stop_all()
...
await pool.gather_and_close()
This may, at first glance, not seem like much of a difference, aside from different method names. However, assume that our main function runs a loop and needs to be able to periodically regulate the number of tasks being executed in the pool based on some additional variables it receives. With the :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>`, this could not be simpler:
.. code-block:: python
:caption: main.py
:name: simple-control-logic
from asyncio_taskpool import SimpleTaskPool
from .work import queue_worker_function
async def main():
...
pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out))
await pool.start(5)
while True:
...
if some_condition and pool.num_running > 10:
pool.stop(3)
elif some_other_condition and pool.num_running < 5:
pool.start(5)
else:
pool.start(1)
...
await pool.gather_and_close()
Notice how we only specify the function and its arguments during initialization of the pool. From that point on, all we need is the :py:meth:`.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` add :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` methods to adjust the number of concurrently running tasks.
The trade-off here is that this simplified task pool class lacks the flexibility of the regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` class. On an instance of the latter we can call :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` and :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` as often as we like with completely unrelated functions and arguments. With a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>`, once you initialize it, it is pegged to one function and one set of arguments, and all you can do is control the number of tasks working with those.
This simplified interface becomes particularly useful in conjunction with the :doc:`control server <./control>`.
.. _blocking-pool-methods:
(Non-)Blocking pool methods
---------------------------
One of the main concerns when dealing with concurrent programs in general and with :code:`async` functions in particular is when and how a particular piece of code **blocks** during execution, i.e. delays the execution of the following code significantly.
.. note::
Every statement will block to *some* extent. Obviously, when a program does something, that takes time. This is why the proper question to ask is not *if* but *to what extent, under which circumstances* the execution of a particular line of code blocks.
It is fair to assume that anyone reading this is familiar enough with the concepts of asynchronous programming in Python to know that just slapping :code:`async` in front of a function definition will not magically make it suitable for concurrent execution (in any meaningful way). Therefore, we assume that you are dealing with coroutines that can actually unblock the `event loop <https://docs.python.org/3/library/asyncio-eventloop.html>`_ (e.g. doing a significant amount of I/O).
So how does the task pool behave in that regard?
The only method of a pool that one should **always** assume to be blocking is :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>`. This method awaits **all** tasks in the pool, meaning as long as one of them is still running, this coroutine will not return.
.. warning::
This includes awaiting any callbacks that were passed along with the tasks.
One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations.
All methods that add tasks to a pool, i.e. :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants), :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>`, are non-blocking by design. They all make use of "meta tasks" under the hood and return immediately. It is important however, to realize that just because they return, does not mean that any actual tasks have been spawned. For example, if a pool size limit was set and there was "no more room" in the pool when :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` was called, there is **no guarantee** that even a single task has started, when it returns.

View File

@ -1,2 +1,4 @@
-r common.txt -r common.txt
coverage coverage
sphinx
sphinx-rtd-theme

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = asyncio-taskpool name = asyncio-taskpool
version = 0.8.0 version = 1.1.4
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Dynamically manage pools of asyncio tasks description = Dynamically manage pools of asyncio tasks
@ -11,7 +11,7 @@ url = https://git.fajnberg.de/daniil/asyncio-taskpool
project_urls = project_urls =
Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues
classifiers = classifiers =
Development Status :: 3 - Alpha Development Status :: 5 - Production/Stable
Programming Language :: Python :: 3 Programming Language :: Python :: 3
Operating System :: OS Independent Operating System :: OS Independent
License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
@ -30,6 +30,8 @@ python_requires = >=3.8
[options.extras_require] [options.extras_require]
dev = dev =
coverage coverage
sphinx
sphinx-rtd-theme
[options.packages.find] [options.packages.find]
where = src where = src

View File

@ -14,10 +14,5 @@ See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Brings the main classes up to package level for import convenience.
"""
from .control.server import TCPControlServer, UnixControlServer
from .pool import TaskPool, SimpleTaskPool from .pool import TaskPool, SimpleTaskPool

View File

@ -0,0 +1,2 @@
from .server import TCPControlServer, UnixControlServer
from .client import TCPControlClient, UnixControlClient

View File

@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
CLI client entry point. CLI entry point script for a :class:`ControlClient`.
""" """
@ -24,22 +24,25 @@ from asyncio import run
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
from ..constants import PACKAGE_NAME from ..internals.constants import PACKAGE_NAME
from ..pool import TaskPool from ..pool import TaskPool
from .client import ControlClient, TCPControlClient, UnixControlClient from .client import TCPControlClient, UnixControlClient
from .server import TCPControlServer, UnixControlServer from .server import TCPControlServer, UnixControlServer
__all__ = []
CLIENT_CLASS = 'client_class' CLIENT_CLASS = 'client_class'
UNIX, TCP = 'unix', 'tcp' UNIX, TCP = 'unix', 'tcp'
SOCKET_PATH = 'path' SOCKET_PATH = 'socket_path'
HOST, PORT = 'host', 'port' HOST, PORT = 'host', 'port'
def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]: def parse_cli(args: Sequence[str] = None) -> Dict[str, Any]:
parser = ArgumentParser( parser = ArgumentParser(
prog=f'{PACKAGE_NAME}.control', prog=f'{PACKAGE_NAME}.control',
description=f"Simple CLI based {ControlClient.__name__} for {PACKAGE_NAME}" description=f"Simple CLI based control client for {PACKAGE_NAME}"
) )
subparsers = parser.add_subparsers(title="Connection types") subparsers = parser.add_subparsers(title="Connection types")

View File

@ -27,13 +27,24 @@ from asyncio.streams import StreamReader, StreamWriter, open_connection
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from ..constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES from ..internals.constants import CLIENT_INFO, SESSION_MSG_BYTES
from ..types import ClientConnT, PathT from ..internals.types import ClientConnT, PathT
__all__ = [
'ControlClient',
'TCPControlClient',
'UnixControlClient',
'CLIENT_EXIT'
]
CLIENT_EXIT = 'exit'
class ControlClient(ABC): class ControlClient(ABC):
""" """
Abstract base class for a simple implementation of a task pool control client. Abstract base class for a simple implementation of a pool control client.
Since the server's control interface is simply expecting commands to be sent, any process able to connect to the Since the server's control interface is simply expecting commands to be sent, any process able to connect to the
TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well. TCP or UNIX socket and issue the relevant commands (and optionally read the responses) will work just as well.
@ -58,7 +69,7 @@ class ControlClient(ABC):
raise NotImplementedError raise NotImplementedError
def __init__(self, **conn_kwargs) -> None: def __init__(self, **conn_kwargs) -> None:
"""Simply stores the connection keyword-arguments necessary for opening the connection.""" """Simply stores the keyword-arguments for opening the connection."""
self._conn_kwargs = conn_kwargs self._conn_kwargs = conn_kwargs
self._connected: bool = False self._connected: bool = False
@ -74,6 +85,7 @@ class ControlClient(ABC):
""" """
self._connected = True self._connected = True
writer.write(json.dumps(self._client_info()).encode()) writer.write(json.dumps(self._client_info()).encode())
writer.write(b'\n')
await writer.drain() await writer.drain()
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode()) print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
print("Type '-h' to get help and usage instructions for all available commands.\n") print("Type '-h' to get help and usage instructions for all available commands.\n")
@ -86,21 +98,22 @@ class ControlClient(ABC):
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
Returns: Returns:
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect; `None`, if either `Ctrl+C` was hit, an empty or whitespace-only string was entered, or the user wants the
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase. client to disconnect; otherwise, returns the user's input, stripped of leading and trailing spaces and
converted to lowercase.
""" """
try: try:
msg = input("> ").strip().lower() cmd = input("> ").strip().lower()
except EOFError: # Ctrl+D shall be equivalent to the `CLIENT_EXIT` command. except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
msg = CLIENT_EXIT cmd = CLIENT_EXIT
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt. except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
print() print()
return return
if msg == CLIENT_EXIT: if cmd == CLIENT_EXIT:
writer.close() writer.close()
self._connected = False self._connected = False
return return
return msg return cmd or None # will be None if `cmd` is an empty string
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None: async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
""" """
@ -119,6 +132,7 @@ class ControlClient(ABC):
try: try:
# Send the command to the server. # Send the command to the server.
writer.write(cmd.encode()) writer.write(cmd.encode())
writer.write(b'\n')
await writer.drain() await writer.drain()
except ConnectionError as e: except ConnectionError as e:
self._connected = False self._connected = False
@ -129,11 +143,14 @@ class ControlClient(ABC):
async def start(self) -> None: async def start(self) -> None:
""" """
This method opens the pre-defined connection, performs the server-handshake, and enters the interaction loop. Opens connection, performs handshake, and enters interaction loop.
An input prompt is presented to the user and any input is sent (encoded) to the connected server.
One exception is the :const:`CLIENT_EXIT` command (equivalent to Ctrl+D), which merely closes the connection.
If the connection can not be established, an error message is printed to `stderr` and the method returns. If the connection can not be established, an error message is printed to `stderr` and the method returns.
If the `_connected` flag is set to `False` during the interaction loop, the method returns and prints out a If either the exit command is issued or the connection to the server is lost during the interaction loop,
disconnected-message. the method returns and prints out a disconnected-message.
""" """
reader, writer = await self._open_connection(**self._conn_kwargs) reader, writer = await self._open_connection(**self._conn_kwargs)
if reader is None: if reader is None:
@ -146,10 +163,10 @@ class ControlClient(ABC):
class TCPControlClient(ControlClient): class TCPControlClient(ControlClient):
"""Task pool control client that expects a TCP socket to be exposed by the control server.""" """Task pool control client for connecting to a :class:`TCPControlServer`."""
def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None: def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None:
"""In addition to what the base class does, `host` and `port` are expected as non-optional arguments.""" """`host` and `port` are expected as non-optional connection arguments."""
self._host = host self._host = host
self._port = port self._port = port
super().__init__(**conn_kwargs) super().__init__(**conn_kwargs)
@ -169,10 +186,10 @@ class TCPControlClient(ControlClient):
class UnixControlClient(ControlClient): class UnixControlClient(ControlClient):
"""Task pool control client that expects a unix socket to be exposed by the control server.""" """Task pool control client for connecting to a :class:`UnixControlServer`."""
def __init__(self, socket_path: PathT, **conn_kwargs) -> None: def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
"""In addition to what the base class does, the `socket_path` is expected as a non-optional argument.""" """`socket_path` is expected as a non-optional connection argument."""
from asyncio.streams import open_unix_connection from asyncio.streams import open_unix_connection
self._open_unix_connection = open_unix_connection self._open_unix_connection = open_unix_connection
self._socket_path = Path(socket_path) self._socket_path = Path(socket_path)

View File

@ -15,21 +15,31 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
This module contains the the definition of the `ControlParser` class used by a control server. Definition of the :class:`ControlParser` used in a
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
It should not be considered part of the public API.
""" """
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS import logging
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
from ast import literal_eval from ast import literal_eval
from asyncio.streams import StreamWriter
from inspect import Parameter, getmembers, isfunction, signature from inspect import Parameter, getmembers, isfunction, signature
from io import StringIO
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
from ..constants import CLIENT_INFO, CMD, STREAM_WRITER
from ..exceptions import HelpRequested, ParserError from ..exceptions import HelpRequested, ParserError
from ..helpers import get_first_doc_line, resolve_dotted_path from ..internals.constants import CLIENT_INFO, CMD
from ..types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT from ..internals.helpers import get_first_doc_line, resolve_dotted_path
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
__all__ = ['ControlParser']
log = logging.getLogger(__name__)
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter]) FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
@ -42,10 +52,10 @@ NAME, PROG, HELP, DESCRIPTION = 'name', 'prog', 'help', 'description'
class ControlParser(ArgumentParser): class ControlParser(ArgumentParser):
""" """
Subclass of the standard `argparse.ArgumentParser` for remote interaction. Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter` Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a file-like
instance passed to it during initialization. `StringIO` instance passed to it during initialization.
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
connected client. connected client.
Finally, it offers some convenience methods and makes use of custom exceptions. Finally, it offers some convenience methods and makes use of custom exceptions.
@ -54,16 +64,18 @@ class ControlParser(ArgumentParser):
@staticmethod @staticmethod
def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls: def help_formatter_factory(terminal_width: int, base_cls: FmtCls = None) -> FmtCls:
""" """
Constructs and returns a subclass of `argparse.HelpFormatter` with a fixed terminal width argument. Constructs and returns a subclass of :class:`argparse.HelpFormatter`
Although a custom formatter class can be explicitly passed into the `ArgumentParser` constructor, this is not The formatter class will have the defined `terminal_width`.
as convenient, when making use of sub-parsers.
Although a custom formatter class can be explicitly passed into the :class:`ArgumentParser` constructor,
this is not as convenient, when making use of sub-parsers.
Args: Args:
terminal_width: terminal_width:
The number of columns of the terminal to which to adjust help formatting. The number of columns of the terminal to which to adjust help formatting.
base_cls (optional): base_cls (optional):
The base class to use for inheritance. By default `argparse.ArgumentDefaultsHelpFormatter` is used. Base class to use for inheritance. By default :class:`argparse.ArgumentDefaultsHelpFormatter` is used.
Returns: Returns:
The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`. The subclass of `base_cls` which fixes the constructor's `width` keyword-argument to `terminal_width`.
@ -77,27 +89,23 @@ class ControlParser(ArgumentParser):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
return ClientHelpFormatter return ClientHelpFormatter
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, def __init__(self, stream: StringIO, terminal_width: int = None, **kwargs) -> None:
**kwargs) -> None:
""" """
Subclass of the `ArgumentParser` geared towards asynchronous interaction with an object "from the outside". Sets some internal attributes in addition to the base class.
Allows directing output to a specified writer rather than stdout/stderr and setting terminal width explicitly.
Args: Args:
stream_writer: stream:
The instance of the `asyncio.StreamWriter` to use for message output. A file-like I/O object to use for message output.
terminal_width (optional): terminal_width (optional):
The terminal width to use for all message formatting. Defaults to `shutil.get_terminal_size().columns`. The terminal width to use for all message formatting. By default the :code:`columns` attribute from
:func:`shutil.get_terminal_size` is taken.
**kwargs(optional): **kwargs(optional):
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
class is specified, it will always be subclassed in the `help_formatter_factory`. class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
""" """
self._stream_writer: StreamWriter = stream_writer self._stream: StringIO = stream
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class')) kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
kwargs.setdefault('exit_on_error', False)
super().__init__(**kwargs) super().__init__(**kwargs)
self._flags: Set[str] = set() self._flags: Set[str] = set()
self._commands = None self._commands = None
@ -105,12 +113,12 @@ class ControlParser(ArgumentParser):
def add_function_command(self, function: Callable, omit_params: Container[str] = OMIT_PARAMS_DEFAULT, def add_function_command(self, function: Callable, omit_params: Container[str] = OMIT_PARAMS_DEFAULT,
**subparser_kwargs) -> 'ControlParser': **subparser_kwargs) -> 'ControlParser':
""" """
Takes a function along with its parameters and adds a corresponding (sub-)command to the parser. Takes a function and adds a corresponding (sub-)command to the parser.
The `add_subparsers` method must have been called prior to this. The :meth:`add_subparsers` method must have been called prior to this.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. NOTE: Currently, only a limited spectrum of parameters can be accurately converted to parser arguments.
This method works correctly with any public method of the `SimpleTaskPool` class. This method works correctly with any public method of the any task pool class.
Args: Args:
function: function:
@ -118,7 +126,7 @@ class ControlParser(ArgumentParser):
omit_params (optional): omit_params (optional):
Names of function parameters not to add as parser arguments. Names of function parameters not to add as parser arguments.
**subparser_kwargs (optional): **subparser_kwargs (optional):
Passed directly to the `add_parser` method. Passed directly to the :meth:`add_parser` method.
Returns: Returns:
The subparser instance created from the function. The subparser instance created from the function.
@ -133,7 +141,7 @@ class ControlParser(ArgumentParser):
def add_property_command(self, prop: property, cls_name: str = '', **subparser_kwargs) -> 'ControlParser': def add_property_command(self, prop: property, cls_name: str = '', **subparser_kwargs) -> 'ControlParser':
""" """
Same as the `add_function_command` method, but for properties. Same as the :meth:`add_function_command` method, but for properties.
Args: Args:
prop: prop:
@ -141,7 +149,7 @@ class ControlParser(ArgumentParser):
cls_name (optional): cls_name (optional):
Name of the class the property is defined on to appear in the command help text. Name of the class the property is defined on to appear in the command help text.
**subparser_kwargs (optional): **subparser_kwargs (optional):
Passed directly to the `add_parser` method. Passed directly to the :meth:`add_parser` method.
Returns: Returns:
The subparser instance created from the property. The subparser instance created from the property.
@ -164,12 +172,12 @@ class ControlParser(ArgumentParser):
def add_class_commands(self, cls: Type, public_only: bool = True, omit_members: Container[str] = (), def add_class_commands(self, cls: Type, public_only: bool = True, omit_members: Container[str] = (),
member_arg_name: str = CMD) -> ParsersDict: member_arg_name: str = CMD) -> ParsersDict:
""" """
Takes a class and adds its methods and properties as (sub-)commands to the parser. Adds methods/properties of a class as (sub-)commands to the parser.
The `add_subparsers` method must have been called prior to this. The :meth:`add_subparsers` method must have been called prior to this.
NOTE: Currently, only a limited spectrum of function parameters can be accurately converted to parser arguments. NOTE: Currently, only a limited spectrum of function parameters can be accurately converted to parser arguments.
This method works correctly with the `SimpleTaskPool` class. This method works correctly with any task pool class.
Args: Args:
cls: cls:
@ -181,13 +189,12 @@ class ControlParser(ArgumentParser):
member_arg_name (optional): member_arg_name (optional):
After parsing the arguments, depending on which command was invoked by the user, the corresponding After parsing the arguments, depending on which command was invoked by the user, the corresponding
method/property will be stored as an extra argument in the parsed namespace under this attribute name. method/property will be stored as an extra argument in the parsed namespace under this attribute name.
Defaults to `constants.CMD`.
Returns: Returns:
Dictionary mapping class member names to the (sub-)parsers created from them. Dictionary mapping class member names to the (sub-)parsers created from them.
""" """
parsers: ParsersDict = {} parsers: ParsersDict = {}
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width} common_kwargs = {'stream': self._stream, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
for name, member in getmembers(cls): for name, member in getmembers(cls):
if name in omit_members or (name.startswith('_') and public_only): if name in omit_members or (name.startswith('_') and public_only):
continue continue
@ -202,14 +209,14 @@ class ControlParser(ArgumentParser):
return parsers return parsers
def add_subparsers(self, *args, **kwargs): def add_subparsers(self, *args, **kwargs):
"""Adds the subparsers action as an internal attribute before returning it.""" """Adds the subparsers action as an attribute before returning it."""
self._commands = super().add_subparsers(*args, **kwargs) self._commands = super().add_subparsers(*args, **kwargs)
return self._commands return self._commands
def _print_message(self, message: str, *args, **kwargs) -> None: def _print_message(self, message: str, *args, **kwargs) -> None:
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer.""" """This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream buffer."""
if message: if message:
self._stream_writer.write(message.encode()) self._stream.write(message)
def exit(self, status: int = 0, message: str = None) -> None: def exit(self, status: int = 0, message: str = None) -> None:
"""This is overridden to prevent system exit to be invoked.""" """This is overridden to prevent system exit to be invoked."""
@ -217,28 +224,28 @@ class ControlParser(ArgumentParser):
self._print_message(message) self._print_message(message)
def error(self, message: str) -> None: def error(self, message: str) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method.""" """Raises the :exc:`ParserError <asyncio_taskpool.exceptions.ParserError>` exception at the end."""
super().error(message=message) super().error(message=message)
raise ParserError raise ParserError
def print_help(self, file=None) -> None: def print_help(self, file=None) -> None:
"""This just adds the custom `HelpRequested` exception after the parent class' method.""" """Raises the :exc:`HelpRequested <asyncio_taskpool.exceptions.HelpRequested>` exception at the end."""
super().print_help(file) super().print_help(file)
raise HelpRequested raise HelpRequested
def add_function_arg(self, parameter: Parameter, **kwargs) -> Action: def add_function_arg(self, parameter: Parameter, **kwargs) -> Action:
""" """
Takes an `inspect.Parameter` of a function and adds a corresponding argument to the parser. Takes an :class:`inspect.Parameter` and adds a corresponding parser argument.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any parameter of any public method of the `SimpleTaskPool` class. This method works correctly with any parameter of any public method any task pool class.
Args: Args:
parameter: The `inspect.Parameter` object to be converted to a parser argument. parameter: The :class:`inspect.Parameter` object to be converted to a parser argument.
**kwargs: Passed to the `add_argument` method of the base class. **kwargs: Passed to the :meth:`add_argument` method of the base class.
Returns: Returns:
The `argparse.Action` returned by the `add_argument` method. The :class:`argparse.Action` returned by the :meth:`add_argument` method.
""" """
if parameter.default is Parameter.empty: if parameter.default is Parameter.empty:
# A non-optional function parameter should correspond to a positional argument. # A non-optional function parameter should correspond to a positional argument.
@ -273,10 +280,10 @@ class ControlParser(ArgumentParser):
def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None: def add_function_args(self, function: Callable, omit: Container[str] = OMIT_PARAMS_DEFAULT) -> None:
""" """
Takes a function reference and adds its parameters as arguments to the parser. Takes a function and adds its parameters as arguments to the parser.
NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument. NOTE: Currently, only a limited spectrum of parameters can be accurately converted to a parser argument.
This method works correctly with any public method of the `SimpleTaskPool` class. This method works correctly with any public method of any task pool class.
Args: Args:
function: function:
@ -297,14 +304,37 @@ def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments. Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
See: https://bugs.python.org/issue36078 See: https://bugs.python.org/issue36078
In addition, the type conversion wrapper catches exceptions not handled properly by the parser, logs them, and
turns them into `ArgumentTypeError` exceptions the parser can propagate to the client.
""" """
def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg) def wrapper(arg: Any) -> Any:
if arg is SUPPRESS:
return arg
try:
return cls(arg)
except (ArgumentTypeError, TypeError, ValueError):
raise # handled properly by the parser and propagated to the client anyway
except Exception as e:
text = f"{e.__class__.__name__} occurred in parser trying to convert type: {cls.__name__}({repr(arg)})"
log.exception(text)
raise ArgumentTypeError(text) # propagate to the client
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed. # Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
wrapper.__name__ = cls.__name__ wrapper.__name__ = cls.__name__
return wrapper return wrapper
def _get_type_from_annotation(annotation: Type) -> Callable[[Any], Any]: def _get_type_from_annotation(annotation: Type) -> Callable[[Any], Any]:
"""
Returns a type conversion function based on the `annotation` passed.
Required to properly convert parsed arguments to the type expected by certain pool methods.
Each conversion function is wrapped by `_get_arg_type_wrapper`.
`Callable`-type annotations give the `resolve_dotted_path` function.
`Iterable`- or args/kwargs-type annotations give the `ast.literal_eval` function.
Others pass unchanged (but still wrapped with `_get_arg_type_wrapper`).
"""
if any(annotation is t for t in {CoroutineFunc, EndCB, CancelCB}): if any(annotation is t for t in {CoroutineFunc, EndCB, CancelCB}):
annotation = resolve_dotted_path annotation = resolve_dotted_path
if any(annotation is t for t in {ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]}): if any(annotation is t for t in {ArgsT, KwArgsT, Iterable[ArgsT], Iterable[KwArgsT]}):

View File

@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
This module contains the task pool control server class definitions. Task pool control server class definitions.
""" """
@ -28,10 +28,14 @@ from asyncio.tasks import Task, create_task
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from ..pool import TaskPool, SimpleTaskPool
from ..types import ConnectedCallbackT
from .client import ControlClient, TCPControlClient, UnixControlClient from .client import ControlClient, TCPControlClient, UnixControlClient
from .session import ControlSession from .session import ControlSession
from ..pool import AnyTaskPoolT
from ..internals.helpers import classmethod
from ..internals.types import ConnectedCallbackT, PathT
__all__ = ['ControlServer', 'TCPControlServer', 'UnixControlServer']
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -41,17 +45,52 @@ class ControlServer(ABC):
""" """
Abstract base class for a task pool control server. Abstract base class for a task pool control server.
This class acts as a wrapper around an async server instance and initializes a `ControlSession` upon a client This class acts as a wrapper around an async server instance and initializes a
connecting to it. The entire interface is defined within that session class. :class:`ControlSession <asyncio_taskpool.control.session.ControlSession>` once a client connects to it.
The interface is defined within the session class.
""" """
_client_class = ControlClient _client_class = ControlClient
@classmethod @classmethod
@property @property
def client_class_name(cls) -> str: def client_class_name(cls) -> str:
"""Returns the name of the control client class matching the server class.""" """Returns the name of the matching control client class."""
return cls._client_class.__name__ return cls._client_class.__name__
def __init__(self, pool: AnyTaskPoolT, **server_kwargs) -> None:
"""
Merely sets internal attributes, but does not start the server yet.
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
tied to one specific task pool.
Args:
pool:
An instance of a `BaseTaskPool` subclass to tie the server to.
**server_kwargs (optional):
Keyword arguments that will be passed into the function that starts the server.
"""
self._pool: AnyTaskPoolT = pool
self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None
@property
def pool(self) -> AnyTaskPoolT:
"""The task pool instance controlled by the server."""
return self._pool
def is_serving(self) -> bool:
"""Wrapper around the `asyncio.Server.is_serving` method."""
return self._server.is_serving()
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
The universal client callback that will be passed into the `_get_server_instance` method.
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
"""
session = ControlSession(self, reader, writer)
await session.client_handshake()
await session.listen()
@abstractmethod @abstractmethod
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
""" """
@ -74,40 +113,6 @@ class ControlServer(ABC):
"""The method to run after the server's `serve_forever` methods ends for whatever reason.""" """The method to run after the server's `serve_forever` methods ends for whatever reason."""
raise NotImplementedError raise NotImplementedError
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None:
"""
Initializes by merely saving the internal attributes, but without starting the server yet.
The task pool must be passed here and can not be set/changed afterwards. This means a control server is always
tied to one specific task pool.
Args:
pool:
An instance of a `BaseTaskPool` subclass to tie the server to.
**server_kwargs (optional):
Keyword arguments that will be passed into the function that starts the server.
"""
self._pool: Union[TaskPool, SimpleTaskPool] = pool
self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None
@property
def pool(self) -> Union[TaskPool, SimpleTaskPool]:
"""Read-only property for accessing the task pool instance controlled by the server."""
return self._pool
def is_serving(self) -> bool:
"""Wrapper around the `asyncio.Server.is_serving` method."""
return self._server.is_serving()
async def _client_connected_cb(self, reader: StreamReader, writer: StreamWriter) -> None:
"""
The universal client callback that will be passed into the `_get_server_instance` method.
Instantiates a control session, performs the client handshake, and enters the session's `listen` loop.
"""
session = ControlSession(self, reader, writer)
await session.client_handshake()
await session.listen()
async def _serve_forever(self) -> None: async def _serve_forever(self) -> None:
""" """
To be run as an `asyncio.Task` by the following method. To be run as an `asyncio.Task` by the following method.
@ -124,9 +129,12 @@ class ControlServer(ABC):
async def serve_forever(self) -> Task: async def serve_forever(self) -> Task:
""" """
This method actually starts the server and begins listening to client connections on the specified interface. Starts the server and begins listening to client connections.
It should never block because the serving will be performed in a separate task. It should never block because the serving will be performed in a separate task.
Returns:
The forever serving task. To stop the server, this task should be cancelled.
""" """
log.debug("Starting %s...", self.__class__.__name__) log.debug("Starting %s...", self.__class__.__name__)
self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs) self._server = await self._get_server_instance(self._client_connected_cb, **self._server_kwargs)
@ -134,12 +142,13 @@ class ControlServer(ABC):
class TCPControlServer(ControlServer): class TCPControlServer(ControlServer):
"""Task pool control server class that exposes a TCP socket for control clients to connect to.""" """Exposes a TCP socket for control clients to connect to."""
_client_class = TCPControlClient _client_class = TCPControlClient
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None: def __init__(self, pool: AnyTaskPoolT, host: str, port: Union[int, str], **server_kwargs) -> None:
self._host = server_kwargs.pop('host') """`host` and `port` are expected as non-optional server arguments."""
self._port = server_kwargs.pop('port') self._host = host
self._port = port
super().__init__(pool, **server_kwargs) super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
@ -152,13 +161,14 @@ class TCPControlServer(ControlServer):
class UnixControlServer(ControlServer): class UnixControlServer(ControlServer):
"""Task pool control server class that exposes a unix socket for control clients to connect to.""" """Exposes a unix socket for control clients to connect to."""
_client_class = UnixControlClient _client_class = UnixControlClient
def __init__(self, pool: Union[TaskPool, SimpleTaskPool], **server_kwargs) -> None: def __init__(self, pool: AnyTaskPoolT, socket_path: PathT, **server_kwargs) -> None:
"""`socket_path` is expected as a non-optional server argument."""
from asyncio.streams import start_unix_server from asyncio.streams import start_unix_server
self._start_unix_server = start_unix_server self._start_unix_server = start_unix_server
self._socket_path = Path(server_kwargs.pop('path')) self._socket_path = Path(socket_path)
super().__init__(pool, **server_kwargs) super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:

View File

@ -15,7 +15,9 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
This module contains the the definition of the `ControlSession` class used by the control server. Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
It should not be considered part of the public API.
""" """
@ -24,32 +26,36 @@ import json
from argparse import ArgumentError from argparse import ArgumentError
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from inspect import isfunction, signature from inspect import isfunction, signature
from io import StringIO
from typing import Callable, Optional, Union, TYPE_CHECKING from typing import Callable, Optional, Union, TYPE_CHECKING
from ..constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
from ..exceptions import CommandError, HelpRequested, ParserError
from ..helpers import return_or_exception
from ..pool import TaskPool, SimpleTaskPool
from .parser import ControlParser from .parser import ControlParser
from ..exceptions import CommandError, HelpRequested, ParserError
from ..pool import TaskPool, SimpleTaskPool
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK
from ..internals.helpers import return_or_exception
if TYPE_CHECKING: if TYPE_CHECKING:
from .server import ControlServer from .server import ControlServer
__all__ = ['ControlSession']
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ControlSession: class ControlSession:
""" """
This class defines the API for controlling a task pool instance from the outside. Manages a single control session between a server and a client.
The commands received from a connected client are translated into method calls on the task pool instance. The commands received from a connected client are translated into method calls on the task pool instance.
A subclass of the standard `argparse.ArgumentParser` is used to handle the input read from the stream. A subclass of the standard :class:`argparse.ArgumentParser` is used to handle the input read from the stream.
""" """
def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None: def __init__(self, server: 'ControlServer', reader: StreamReader, writer: StreamWriter) -> None:
""" """
Instantiation should happen once a client connection to the control server has already been established. Connection to the control server should already been established.
For more convenient/efficient access, some of the server's properties are saved in separate attributes. For more convenient/efficient access, some of the server's properties are saved in separate attributes.
The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during The argument parser is _not_ instantiated in the constructor. It requires a bit of client information during
@ -57,7 +63,7 @@ class ControlSession:
Args: Args:
server: server:
The instance of a `ControlServer` subclass starting the session. The instance of a :class:`ControlServer` subclass starting the session.
reader: reader:
The `asyncio.StreamReader` created when a client connected to the server. The `asyncio.StreamReader` created when a client connected to the server.
writer: writer:
@ -69,14 +75,16 @@ class ControlSession:
self._reader: StreamReader = reader self._reader: StreamReader = reader
self._writer: StreamWriter = writer self._writer: StreamWriter = writer
self._parser: Optional[ControlParser] = None self._parser: Optional[ControlParser] = None
self._response_buffer: StringIO = StringIO()
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None: async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
""" """
Takes a pool method reference, executes it, and writes a response accordingly. Takes a pool method reference, executes it, and writes a response accordingly.
If the first parameter is named `self`, the method will be called with the `_pool` instance as its first If the first parameter is named `self`, the method will be called with the `_pool` instance as its first
positional argument. If it returns nothing, the response upon successful execution will be `constants.CMD_OK`, positional argument.
otherwise the response written to the stream will be its return value (as an encoded string). If it returns nothing, the response upon successful execution will be :const:`constants.CMD_OK`, otherwise the
response written to the stream will be its return value (as an encoded string).
Args: Args:
prop: prop:
@ -95,7 +103,7 @@ class ControlSession:
elif param.kind == param.VAR_POSITIONAL: elif param.kind == param.VAR_POSITIONAL:
var_pos = kwargs.pop(param.name) var_pos = kwargs.pop(param.name)
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs) output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
self._writer.write(CMD_OK if output is None else str(output).encode()) self._response_buffer.write(CMD_OK.decode() if output is None else str(output))
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None: async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
""" """
@ -108,27 +116,29 @@ class ControlSession:
The reference to the property defined on the `_pool` instance's class. The reference to the property defined on the `_pool` instance's class.
**kwargs (optional): **kwargs (optional):
If not empty, the property setter is executed and the keyword arguments are passed along to it; the If not empty, the property setter is executed and the keyword arguments are passed along to it; the
response upon successful execution will be `constants.CMD_OK`. Otherwise the property getter is response upon successful execution will be :const:`constants.CMD_OK`. Otherwise the property getter is
executed and the response written to the stream will be its return value (as an encoded string). executed and the response written to the stream will be its return value (as an encoded string).
""" """
if kwargs: if kwargs:
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__) log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
await return_or_exception(prop.fset, self._pool, **kwargs) await return_or_exception(prop.fset, self._pool, **kwargs)
self._writer.write(CMD_OK) self._response_buffer.write(CMD_OK.decode())
else: else:
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__) log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode()) self._response_buffer.write(str(await return_or_exception(prop.fget, self._pool)))
async def client_handshake(self) -> None: async def client_handshake(self) -> None:
""" """
This method must be invoked before starting any other client interaction. Must be invoked before starting any other client interaction.
Client info is retrieved, server info is sent back, and the `ControlParser` is initialized and configured. Client info is retrieved, server info is sent back, and the
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
""" """
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip()) msg = (await self._reader.readline()).decode().strip()
client_info = json.loads(msg)
log.debug("%s connected", self._client_class_name) log.debug("%s connected", self._client_class_name)
parser_kwargs = { parser_kwargs = {
STREAM_WRITER: self._writer, 'stream': self._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH], CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
'prog': '', 'prog': '',
'usage': f'[-h] [{CMD}] ...' 'usage': f'[-h] [{CMD}] ...'
@ -137,16 +147,16 @@ class ControlSession:
self._parser.add_subparsers(title="Commands", self._parser.add_subparsers(title="Commands",
metavar="(A command followed by '-h' or '--help' will show command-specific help.)") metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
self._parser.add_class_commands(self._pool.__class__) self._parser.add_class_commands(self._pool.__class__)
self._writer.write(str(self._pool).encode()) self._writer.write(str(self._pool).encode() + b'\n')
await self._writer.drain() await self._writer.drain()
async def _parse_command(self, msg: str) -> None: async def _parse_command(self, msg: str) -> None:
""" """
Takes a message from the client and attempts to parse it. Takes a message from the client and attempts to parse it.
If a parsing error occurs, it is returned to the client. If the `HelpRequested` exception was raised by the If a parsing error occurs, it is returned to the client. If the :exc:`HelpRequested` exception was raised by the
`ControlParser`, nothing else happens. Otherwise, the appropriate `_exec...` method is called with the entire :class:`ControlParser`, nothing else happens. Otherwise, the appropriate `_exec...` method is called with the
dictionary of keyword-arguments returned by the `ControlParser` passed into it. entire dictionary of keyword-arguments returned by the :class:`ControlParser` passed into it.
Args: Args:
msg: The non-empty string read from the client stream. msg: The non-empty string read from the client stream.
@ -155,7 +165,7 @@ class ControlSession:
kwargs = vars(self._parser.parse_args(msg.split(' '))) kwargs = vars(self._parser.parse_args(msg.split(' ')))
except ArgumentError as e: except ArgumentError as e:
log.debug("%s got an ArgumentError", self._client_class_name) log.debug("%s got an ArgumentError", self._client_class_name)
self._writer.write(str(e).encode()) self._response_buffer.write(str(e))
return return
except (HelpRequested, ParserError): except (HelpRequested, ParserError):
log.debug("%s received usage help", self._client_class_name) log.debug("%s received usage help", self._client_class_name)
@ -166,20 +176,25 @@ class ControlSession:
elif isinstance(command, property): elif isinstance(command, property):
await self._exec_property_and_respond(command, **kwargs) await self._exec_property_and_respond(command, **kwargs)
else: else:
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode()) self._response_buffer.write(str(CommandError(f"Unknown command object: {command}")))
async def listen(self) -> None: async def listen(self) -> None:
""" """
Enters the main control loop that only ends if either the server or the client disconnect. Enters the main control loop listening to client input.
Messages from the client are read and passed into the `_parse_command` method, which handles the rest. This method only returns if either the server or the client disconnect.
Messages from the client are read, parsed, and turned into pool commands (if possible).
This method should be called, when the client connection was established and the handshake was successful. This method should be called, when the client connection was established and the handshake was successful.
It will obviously block indefinitely. It will obviously block indefinitely.
""" """
while self._control_server.is_serving(): while self._control_server.is_serving():
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip() msg = (await self._reader.readline()).decode().strip()
if not msg: if not msg:
log.debug("%s disconnected", self._client_class_name) log.debug("%s disconnected", self._client_class_name)
break break
await self._parse_command(msg) await self._parse_command(msg)
response = self._response_buffer.getvalue() + "\n"
self._response_buffer.seek(0)
self._response_buffer.truncate()
self._writer.write(response.encode())
await self._writer.drain() await self._writer.drain()

View File

@ -51,10 +51,6 @@ class InvalidGroupName(PoolException):
pass pass
class PoolStillUnlocked(PoolException):
pass
class NotCoroutine(PoolException): class NotCoroutine(PoolException):
pass pass

View File

@ -1,85 +0,0 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Miscellaneous helper functions. None of these should be considered part of the public API.
"""
from asyncio.coroutines import iscoroutinefunction
from asyncio.queues import Queue
from importlib import import_module
from inspect import getdoc
from typing import Any, Optional, Union
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
if not callable(function):
return
if kwargs is None:
kwargs = {}
if iscoroutinefunction(function):
return await function(*args, **kwargs)
return function(*args, **kwargs)
def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
if arg_stars == 0:
return function(arg)
if arg_stars == 1:
return function(*arg)
if arg_stars == 2:
return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
async def join_queue(q: Queue) -> None:
await q.join()
def get_first_doc_line(obj: object) -> str:
return getdoc(obj).strip().split("\n", 1)[0].strip()
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
try:
if iscoroutinefunction(_function_to_execute):
return await _function_to_execute(*args, **kwargs)
else:
return _function_to_execute(*args, **kwargs)
except Exception as e:
return e
def resolve_dotted_path(dotted_path: str) -> object:
"""
Resolves a dotted path to a global object and returns that object.
Algorithm shamelessly stolen from the `logging.config` module from the standard library.
"""
names = dotted_path.split('.')
module_name = names.pop(0)
found = import_module(module_name)
for name in names:
try:
found = getattr(found, name)
except AttributeError:
module_name += f'.{name}'
import_module(module_name)
found = getattr(found, name)
return found

View File

@ -16,19 +16,22 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
Constants used by more than one module in the package. Constants used by more than one module in the package.
This module should **not** be considered part of the public API.
""" """
import sys
PACKAGE_NAME = 'asyncio_taskpool' PACKAGE_NAME = 'asyncio_taskpool'
DEFAULT_TASK_GROUP = '' PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
CLIENT_EXIT = 'exit' DEFAULT_TASK_GROUP = 'default'
SESSION_MSG_BYTES = 1024 * 100 SESSION_MSG_BYTES = 1024 * 100
STREAM_WRITER = 'stream_writer'
CMD = 'command' CMD = 'command'
CMD_OK = b"ok" CMD_OK = b"ok"

View File

@ -15,7 +15,9 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
This module contains the definition of the `TaskGroupRegister` class. Definition of :class:`TaskGroupRegister`.
It should not be considered part of the public API.
""" """
@ -26,9 +28,9 @@ from typing import Iterator, Set
class TaskGroupRegister(MutableSet): class TaskGroupRegister(MutableSet):
""" """
This class combines the interface of a regular `set` with that of the `asyncio.Lock`. Combines the interface of a regular `set` with that of the `asyncio.Lock`.
It serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for Serves simultaneously as a container of IDs of tasks that belong to the same group, and as a mechanism for
preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of preventing race conditions within a task group. The lock should be acquired before cancelling the entire group of
tasks, as well as before starting a task within the group. tasks, as well as before starting a task within the group.
""" """

View File

@ -0,0 +1,157 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Miscellaneous helper functions. None of these should be considered part of the public API.
"""
import builtins
from asyncio.coroutines import iscoroutinefunction
from importlib import import_module
from inspect import getdoc
from typing import Any, Callable, Optional, Type, Union
from .constants import PYTHON_BEFORE_39
from .types import T, AnyCallableT, ArgsT, KwArgsT
async def execute_optional(function: AnyCallableT, args: ArgsT = (), kwargs: KwArgsT = None) -> Optional[T]:
"""
Runs `function` with `args` and `kwargs` and returns its output.
Args:
function:
Any callable that accepts the provided positional and keyword-arguments.
If it is a coroutine function, it will be awaited.
If it is not a callable, nothing is returned.
*args (optional):
Positional arguments to pass to `function`.
**kwargs (optional):
Keyword-arguments to pass to `function`.
Returns:
Whatever `function` returns (possibly after being awaited) or `None` if `function` is not callable.
"""
if not callable(function):
return
if kwargs is None:
kwargs = {}
if iscoroutinefunction(function):
return await function(*args, **kwargs)
return function(*args, **kwargs)
def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
"""
Calls `function` passing `arg` to it, optionally unpacking it first.
Args:
function:
Any callable that accepts the provided argument(s).
arg:
The single positional argument that `function` expects; in this case `arg_stars` should be 0.
Or the iterable of positional arguments that `function` expects; in this case `arg_stars` should be 1.
Or the mapping of keyword-arguments that `function` expects; in this case `arg_stars` should be 2.
arg_stars (optional):
Determines if and how to unpack `arg`.
0 means no unpacking, i.e. `arg` is passed into `function` directly as `function(arg)`.
1 means unpacking to an arbitrary number of positional arguments, i.e. as `function(*arg)`.
2 means unpacking to an arbitrary number of keyword-arguments, i.e. as `function(**arg)`.
Returns:
Whatever `function` returns.
Raises:
`ValueError`: `arg_stars` is something other than 0, 1, or 2.
"""
if arg_stars == 0:
return function(arg)
if arg_stars == 1:
return function(*arg)
if arg_stars == 2:
return function(**arg)
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
def get_first_doc_line(obj: object) -> str:
"""Takes an object and returns the first (non-empty) line of its docstring."""
return getdoc(obj).strip().split("\n", 1)[0].strip()
async def return_or_exception(_function_to_execute: AnyCallableT, *args, **kwargs) -> Union[T, Exception]:
"""
Returns the output of a function or the exception thrown during its execution.
Args:
_function_to_execute:
Any callable that accepts the provided positional and keyword-arguments.
*args (optional):
Positional arguments to pass to `_function_to_execute`.
**kwargs (optional):
Keyword-arguments to pass to `_function_to_execute`.
Returns:
Whatever `_function_to_execute` returns or throws. (An exception is not raised, but returned!)
"""
try:
if iscoroutinefunction(_function_to_execute):
return await _function_to_execute(*args, **kwargs)
else:
return _function_to_execute(*args, **kwargs)
except Exception as e:
return e
def resolve_dotted_path(dotted_path: str) -> object:
"""
Resolves a dotted path to a global object and returns that object.
Algorithm shamelessly stolen from the `logging.config` module from the standard library.
"""
names = dotted_path.split('.')
module_name = names.pop(0)
found = import_module(module_name)
for name in names:
try:
found = getattr(found, name)
except AttributeError:
module_name += f'.{name}'
import_module(module_name)
found = getattr(found, name)
return found
class ClassMethodWorkaround:
"""Dirty workaround to make the `@classmethod` decorator work with properties."""
def __init__(self, method_or_property: Union[Callable, property]) -> None:
if isinstance(method_or_property, property):
self._getter = method_or_property.fget
else:
self._getter = method_or_property
def __get__(self, obj: Union[T, None], cls: Union[Type[T], None]) -> Any:
if obj is None:
return self._getter(cls)
return self._getter(obj)
# Starting with Python 3.9, this is thankfully no longer necessary.
if PYTHON_BEFORE_39:
classmethod = ClassMethodWorkaround
else:
classmethod = builtins.classmethod

View File

@ -16,12 +16,14 @@ If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
Custom type definitions used in various modules. Custom type definitions used in various modules.
This module should **not** be considered part of the public API.
""" """
from asyncio.streams import StreamReader, StreamWriter from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path from pathlib import Path
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union
T = TypeVar('T') T = TypeVar('T')
@ -29,8 +31,8 @@ T = TypeVar('T')
ArgsT = Iterable[Any] ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any] KwArgsT = Mapping[str, Any]
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]] AnyCallableT = Callable[..., Union[T, Awaitable[T]]]
CoroutineFunc = Callable[[...], Awaitable[Any]] CoroutineFunc = Callable[..., Coroutine]
EndCB = Callable EndCB = Callable
CancelCB = Callable CancelCB = Callable

File diff suppressed because it is too large Load Diff

View File

@ -15,7 +15,7 @@ You should have received a copy of the GNU Lesser General Public License along w
If not, see <https://www.gnu.org/licenses/>.""" If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """ __doc__ = """
This module contains the definition of an `asyncio.Queue` subclass. Definition of an :code:`asyncio.Queue` subclass with some small additions.
""" """
@ -23,12 +23,20 @@ from asyncio.queues import Queue as _Queue
from typing import Any from typing import Any
__all__ = ['Queue']
class Queue(_Queue): class Queue(_Queue):
"""This just adds a little syntactic sugar to the `asyncio.Queue`.""" """
Adds a little syntactic sugar to the :code:`asyncio.Queue`.
Allows being used as an async context manager awaiting `get` upon entering the context and calling
:meth:`item_processed` upon exiting it.
"""
def item_processed(self) -> None: def item_processed(self) -> None:
""" """
Does exactly the same as `task_done()`. Does exactly the same as :meth:`asyncio.Queue.task_done`.
This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing, This method exists because `task_done` is an atrocious name for the method. It communicates the wrong thing,
invites confusion, and immensely reduces readability (in the context of this library). And readability counts. invites confusion, and immensely reduces readability (in the context of this library). And readability counts.
@ -39,7 +47,7 @@ class Queue(_Queue):
""" """
Implements an asynchronous context manager for the queue. Implements an asynchronous context manager for the queue.
Upon entering `get()` is awaited and subsequently whatever came out of the queue is returned. Upon entering :meth:`get` is awaited and subsequently whatever came out of the queue is returned.
It allows writing code this way: It allows writing code this way:
>>> queue = Queue() >>> queue = Queue()
>>> ... >>> ...
@ -52,7 +60,7 @@ class Queue(_Queue):
""" """
Implements an asynchronous context manager for the queue. Implements an asynchronous context manager for the queue.
Upon exiting `item_processed()` is called. This is why this context manager may not always be what you want, Upon exiting :meth:`item_processed` is called. This is why this context manager may not always be what you want,
but in some situations it makes the code much cleaner. but in some situations it makes the code much cleaner.
""" """
self.item_processed() self.item_processed()

30
tests/__main__.py Normal file
View File

@ -0,0 +1,30 @@
__author__ = "Daniil Fajnberg"
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
__license__ = """GNU LGPLv3.0
This file is part of asyncio-taskpool.
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
If not, see <https://www.gnu.org/licenses/>."""
__doc__ = """
Main entry point for all unit tests.
"""
import sys
import unittest
if __name__ == '__main__':
test_suite = unittest.defaultTestLoader.discover('.')
test_runner = unittest.TextTestRunner(resultclass=unittest.TextTestResult)
result = test_runner.run(test_suite)
sys.exit(not result.wasSuccessful())

View File

@ -38,7 +38,7 @@ class CLITestCase(IsolatedAsyncioTestCase):
mock_client = MagicMock(start=mock_client_start) mock_client = MagicMock(start=mock_client_start)
mock_client_cls = MagicMock(return_value=mock_client) mock_client_cls = MagicMock(return_value=mock_client)
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789} mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls, **mock_client_kwargs}
self.assertIsNone(await module.main()) self.assertIsNone(await module.main())
mock_parse_cli.assert_called_once_with() mock_parse_cli.assert_called_once_with()
mock_client_cls.assert_called_once_with(**mock_client_kwargs) mock_client_cls.assert_called_once_with(**mock_client_kwargs)

View File

@ -28,7 +28,7 @@ from unittest import IsolatedAsyncioTestCase, skipIf
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, MagicMock, call, patch
from asyncio_taskpool.control import client from asyncio_taskpool.control import client
from asyncio_taskpool.constants import CLIENT_INFO, SESSION_MSG_BYTES from asyncio_taskpool.internals.constants import CLIENT_INFO, SESSION_MSG_BYTES
FOO, BAR = 'foo', 'bar' FOO, BAR = 'foo', 'bar'
@ -71,7 +71,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer)) self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
self.assertTrue(self.client._connected) self.assertTrue(self.client._connected)
mock__client_info.assert_called_once_with() mock__client_info.assert_called_once_with()
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode()) self.mock_write.assert_has_calls([call(json.dumps(mock_info).encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with() self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_has_calls([ self.mock_print.assert_has_calls([
@ -121,7 +121,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
mock__get_command.return_value = cmd = FOO + BAR + ' 123' mock__get_command.return_value = cmd = FOO + BAR + ' 123'
self.mock_drain.side_effect = err = ConnectionError() self.mock_drain.side_effect = err = ConnectionError()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode()) self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with() self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_not_awaited() self.mock_read.assert_not_awaited()
self.mock_print.assert_called_once_with(err, file=sys.stderr) self.mock_print.assert_called_once_with(err, file=sys.stderr)
@ -133,7 +133,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
self.mock_print.reset_mock() self.mock_print.reset_mock()
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer)) self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
self.mock_write.assert_called_once_with(cmd.encode()) self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
self.mock_drain.assert_awaited_once_with() self.mock_drain.assert_awaited_once_with()
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
self.mock_print.assert_called_once_with(FOO) self.mock_print.assert_called_once_with(FOO)

View File

@ -28,22 +28,22 @@ from typing import Iterable
from asyncio_taskpool.control import parser from asyncio_taskpool.control import parser
from asyncio_taskpool.exceptions import HelpRequested, ParserError from asyncio_taskpool.exceptions import HelpRequested, ParserError
from asyncio_taskpool.helpers import resolve_dotted_path from asyncio_taskpool.internals.helpers import resolve_dotted_path
from asyncio_taskpool.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT from asyncio_taskpool.internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
FOO, BAR = 'foo', 'bar' FOO, BAR = 'foo', 'bar'
class ControlServerTestCase(TestCase): class ControlParserTestCase(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory') self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start() self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
self.stream_writer, self.terminal_width = MagicMock(), 420 self.stream, self.terminal_width = MagicMock(), 420
self.kwargs = { self.kwargs = {
'stream_writer': self.stream_writer, 'stream': self.stream,
'terminal_width': self.terminal_width, 'terminal_width': self.terminal_width,
'formatter_class': FOO 'formatter_class': FOO
} }
@ -72,10 +72,9 @@ class ControlServerTestCase(TestCase):
def test_init(self): def test_init(self):
self.assertIsInstance(self.parser, ArgumentParser) self.assertIsInstance(self.parser, ArgumentParser)
self.assertEqual(self.stream_writer, self.parser._stream_writer) self.assertEqual(self.stream, self.parser._stream)
self.assertEqual(self.terminal_width, self.parser._terminal_width) self.assertEqual(self.terminal_width, self.parser._terminal_width)
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO) self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
self.assertFalse(getattr(self.parser, 'exit_on_error'))
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class')) self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
self.assertSetEqual(set(), self.parser._flags) self.assertSetEqual(set(), self.parser._flags)
self.assertIsNone(self.parser._commands) self.assertIsNone(self.parser._commands)
@ -89,7 +88,7 @@ class ControlServerTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123' mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'foo-bar' expected_name = 'foo-bar'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
to_omit = ['abc', 'xyz'] to_omit = ['abc', 'xyz']
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs) output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
@ -107,7 +106,7 @@ class ControlServerTestCase(TestCase):
mock_get_first_doc_line.return_value = mock_help = 'help 123' mock_get_first_doc_line.return_value = mock_help = 'help 123'
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR} kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
expected_name = 'get-prop' expected_name = 'get-prop'
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs) output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_called_once_with(get_prop) mock_get_first_doc_line.assert_called_once_with(get_prop)
@ -119,7 +118,7 @@ class ControlServerTestCase(TestCase):
prop = property(get_prop, set_prop) prop = property(get_prop, set_prop)
expected_help = f"Get/set the `.{expected_name}` property" expected_help = f"Get/set the `.{expected_name}` property"
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help, **kwargs}
output = self.parser.add_property_command(prop, **kwargs) output = self.parser.add_property_command(prop, **kwargs)
self.assertEqual(mock_subparser, output) self.assertEqual(mock_subparser, output)
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)]) mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
@ -152,8 +151,7 @@ class ControlServerTestCase(TestCase):
mock_subparser = MagicMock(set_defaults=mock_set_defaults) mock_subparser = MagicMock(set_defaults=mock_set_defaults)
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
x = 'x' x = 'x'
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer, common_kwargs = {'stream': self.parser._stream, parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
expected_output = {'method': mock_subparser, 'prop': mock_subparser} expected_output = {'method': mock_subparser, 'prop': mock_subparser}
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x) output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
self.assertDictEqual(expected_output, output) self.assertDictEqual(expected_output, output)
@ -170,12 +168,12 @@ class ControlServerTestCase(TestCase):
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs) mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
def test__print_message(self): def test__print_message(self):
self.stream_writer.write = MagicMock() self.stream.write = MagicMock()
self.assertIsNone(self.parser._print_message('')) self.assertIsNone(self.parser._print_message(''))
self.stream_writer.write.assert_not_called() self.stream.write.assert_not_called()
msg = 'foo bar baz' msg = 'foo bar baz'
self.assertIsNone(self.parser._print_message(msg)) self.assertIsNone(self.parser._print_message(msg))
self.stream_writer.write.assert_called_once_with(msg.encode()) self.stream.write.assert_called_once_with(msg)
@patch.object(parser.ControlParser, '_print_message') @patch.object(parser.ControlParser, '_print_message')
def test_exit(self, mock__print_message: MagicMock): def test_exit(self, mock__print_message: MagicMock):
@ -265,12 +263,36 @@ class ControlServerTestCase(TestCase):
class RestTestCase(TestCase): class RestTestCase(TestCase):
log_lvl: int
@classmethod
def setUpClass(cls) -> None:
cls.log_lvl = parser.log.level
parser.log.setLevel(999)
@classmethod
def tearDownClass(cls) -> None:
parser.log.setLevel(cls.log_lvl)
def test__get_arg_type_wrapper(self): def test__get_arg_type_wrapper(self):
type_wrap = parser._get_arg_type_wrapper(int) type_wrap = parser._get_arg_type_wrapper(int)
self.assertEqual('int', type_wrap.__name__) self.assertEqual('int', type_wrap.__name__)
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS)) self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
self.assertEqual(13, type_wrap('13')) self.assertEqual(13, type_wrap('13'))
name = 'abcdef'
mock_type = MagicMock(side_effect=[parser.ArgumentTypeError, TypeError, ValueError, Exception], __name__=name)
type_wrap = parser._get_arg_type_wrapper(mock_type)
self.assertEqual(name, type_wrap.__name__)
with self.assertRaises(parser.ArgumentTypeError):
type_wrap(FOO)
with self.assertRaises(TypeError):
type_wrap(FOO)
with self.assertRaises(ValueError):
type_wrap(FOO)
with self.assertRaises(parser.ArgumentTypeError):
type_wrap(FOO)
@patch.object(parser, '_get_arg_type_wrapper') @patch.object(parser, '_get_arg_type_wrapper')
def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock): def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR

View File

@ -183,7 +183,7 @@ class UnixControlServerTestCase(IsolatedAsyncioTestCase):
self.mock_pool = MagicMock() self.mock_pool = MagicMock()
self.path = '/tmp/asyncio_taskpool' self.path = '/tmp/asyncio_taskpool'
self.kwargs = {FOO: 123, BAR: 456} self.kwargs = {FOO: 123, BAR: 456}
self.server = server.UnixControlServer(pool=self.mock_pool, path=self.path, **self.kwargs) self.server = server.UnixControlServer(pool=self.mock_pool, socket_path=self.path, **self.kwargs)
def tearDown(self) -> None: def tearDown(self) -> None:
self.base_init_patcher.stop() self.base_init_patcher.stop()

View File

@ -21,11 +21,12 @@ Unittests for the `asyncio_taskpool.session` module.
import json import json
from argparse import ArgumentError, Namespace from argparse import ArgumentError, Namespace
from io import StringIO
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch, call from unittest.mock import AsyncMock, MagicMock, patch, call
from asyncio_taskpool.control import session from asyncio_taskpool.control import session
from asyncio_taskpool.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD
from asyncio_taskpool.exceptions import HelpRequested from asyncio_taskpool.exceptions import HelpRequested
from asyncio_taskpool.pool import SimpleTaskPool from asyncio_taskpool.pool import SimpleTaskPool
@ -61,18 +62,19 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
self.assertEqual(self.mock_reader, self.session._reader) self.assertEqual(self.mock_reader, self.session._reader)
self.assertEqual(self.mock_writer, self.session._writer) self.assertEqual(self.mock_writer, self.session._writer)
self.assertIsNone(self.session._parser) self.assertIsNone(self.session._parser)
self.assertIsInstance(self.session._response_buffer, StringIO)
@patch.object(session, 'return_or_exception') @patch.object(session, 'return_or_exception')
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock): async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
def method(self, arg1, arg2, *var_args, **rest): pass def method(self, arg1, arg2, *var_args, **rest): pass
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11} test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args}
mock_return_or_exception.return_value = None mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs)) self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs, **test_rest))
mock_return_or_exception.assert_awaited_once_with( mock_return_or_exception.assert_awaited_once_with(
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
) )
self.mock_writer.write.assert_called_once_with(session.CMD_OK) self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
@patch.object(session, 'return_or_exception') @patch.object(session, 'return_or_exception')
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock): async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
@ -83,15 +85,16 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_return_or_exception.return_value = None mock_return_or_exception.return_value = None
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs)) self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs) mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
self.mock_writer.write.assert_called_once_with(session.CMD_OK) self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
mock_return_or_exception.reset_mock() mock_return_or_exception.reset_mock()
self.mock_writer.write.reset_mock() self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_return_or_exception.return_value = val = 420.69 mock_return_or_exception.return_value = val = 420.69
self.assertIsNone(await self.session._exec_property_and_respond(prop)) self.assertIsNone(await self.session._exec_property_and_respond(prop))
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool) mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
self.mock_writer.write.assert_called_once_with(str(val).encode()) self.assertEqual(str(val), self.session._response_buffer.getvalue())
@patch.object(session, 'ControlParser') @patch.object(session, 'ControlParser')
async def test_client_handshake(self, mock_parser_cls: MagicMock): async def test_client_handshake(self, mock_parser_cls: MagicMock):
@ -100,11 +103,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parser_cls.return_value = mock_parser mock_parser_cls.return_value = mock_parser
width = 5678 width = 5678
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' ' msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
mock_read = AsyncMock(return_value=msg.encode()) mock_readline = AsyncMock(return_value=msg.encode())
self.mock_reader.read = mock_read self.mock_reader.readline = mock_readline
self.mock_writer.drain = AsyncMock() self.mock_writer.drain = AsyncMock()
expected_parser_kwargs = { expected_parser_kwargs = {
STREAM_WRITER: self.mock_writer, 'stream': self.session._response_buffer,
CLIENT_INFO.TERMINAL_WIDTH: width, CLIENT_INFO.TERMINAL_WIDTH: width,
'prog': '', 'prog': '',
'usage': f'[-h] [{CMD}] ...' 'usage': f'[-h] [{CMD}] ...'
@ -115,11 +118,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
} }
self.assertIsNone(await self.session.client_handshake()) self.assertIsNone(await self.session.client_handshake())
self.assertEqual(mock_parser, self.session._parser) self.assertEqual(mock_parser, self.session._parser)
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES) mock_readline.assert_awaited_once_with()
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs) mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs) mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__) mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode()) self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode() + b'\n')
self.mock_writer.drain.assert_awaited_once_with() self.mock_writer.drain.assert_awaited_once_with()
@patch.object(session.ControlSession, '_exec_property_and_respond') @patch.object(session.ControlSession, '_exec_property_and_respond')
@ -132,10 +135,9 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
kwargs = {FOO: BAR, 'hello': 'python'} kwargs = {FOO: BAR, 'hello': 'python'}
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs)) mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
self.session._parser = MagicMock(parse_args=mock_parse_args) self.session._parser = MagicMock(parse_args=mock_parse_args)
self.mock_writer.write = MagicMock()
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs) mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
mock__exec_property_and_respond.assert_not_called() mock__exec_property_and_respond.assert_not_called()
@ -145,7 +147,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs) mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_called() mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs) mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
@ -161,47 +163,55 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
mock__exec_method_and_respond.assert_not_called() mock__exec_method_and_respond.assert_not_called()
mock__exec_property_and_respond.assert_not_called() mock__exec_property_and_respond.assert_not_called()
self.mock_writer.write.assert_called_once_with(str(exc).encode()) self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_property_and_respond.reset_mock() mock__exec_property_and_respond.reset_mock()
mock_parse_args.reset_mock() mock_parse_args.reset_mock()
self.mock_writer.write.reset_mock() self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops") mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_called_once_with(str(exc).encode()) self.assertEqual(str(exc), self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited() mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited() mock__exec_property_and_respond.assert_not_awaited()
self.mock_writer.write.reset_mock()
mock_parse_args.reset_mock() mock_parse_args.reset_mock()
self.session._response_buffer.seek(0)
self.session._response_buffer.truncate()
mock_parse_args.side_effect = HelpRequested() mock_parse_args.side_effect = HelpRequested()
self.assertIsNone(await self.session._parse_command(msg)) self.assertIsNone(await self.session._parse_command(msg))
mock_parse_args.assert_called_once_with(msg.split(' ')) mock_parse_args.assert_called_once_with(msg.split(' '))
self.mock_writer.write.assert_not_called() self.assertEqual('', self.session._response_buffer.getvalue())
mock__exec_method_and_respond.assert_not_awaited() mock__exec_method_and_respond.assert_not_awaited()
mock__exec_property_and_respond.assert_not_awaited() mock__exec_property_and_respond.assert_not_awaited()
@patch.object(session.ControlSession, '_parse_command') @patch.object(session.ControlSession, '_parse_command')
async def test_listen(self, mock__parse_command: AsyncMock): async def test_listen(self, mock__parse_command: AsyncMock):
def make_reader_return_empty(): def make_reader_return_empty():
self.mock_reader.read.return_value = b'' self.mock_reader.readline.return_value = b''
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty) self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
msg = "fascinating" msg = "fascinating"
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode()) self.mock_reader.readline = AsyncMock(return_value=f' {msg} '.encode())
response = FOO + BAR + FOO
self.session._response_buffer.write(response)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)]) self.mock_reader.readline.assert_has_awaits([call(), call()])
mock__parse_command.assert_awaited_once_with(msg) mock__parse_command.assert_awaited_once_with(msg)
self.assertEqual('', self.session._response_buffer.getvalue())
self.mock_writer.write.assert_called_once_with(response.encode() + b'\n')
self.mock_writer.drain.assert_awaited_once_with() self.mock_writer.drain.assert_awaited_once_with()
self.mock_reader.read.reset_mock() self.mock_reader.readline.reset_mock()
mock__parse_command.reset_mock() mock__parse_command.reset_mock()
self.mock_writer.write.reset_mock()
self.mock_writer.drain.reset_mock() self.mock_writer.drain.reset_mock()
self.mock_server.is_serving = MagicMock(return_value=False) self.mock_server.is_serving = MagicMock(return_value=False)
self.assertIsNone(await self.session.listen()) self.assertIsNone(await self.session.listen())
self.mock_reader.read.assert_not_awaited() self.mock_reader.readline.assert_not_awaited()
mock__parse_command.assert_not_awaited() mock__parse_command.assert_not_awaited()
self.mock_writer.write.assert_not_called()
self.mock_writer.drain.assert_not_awaited() self.mock_writer.drain.assert_not_awaited()

View File

View File

@ -21,10 +21,9 @@ Unittests for the `asyncio_taskpool.group_register` module.
from asyncio.locks import Lock from asyncio.locks import Lock
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import MagicMock, patch
from asyncio_taskpool import group_register
from asyncio_taskpool.internals import group_register
FOO, BAR = 'foo', 'bar' FOO, BAR = 'foo', 'bar'

View File

@ -18,11 +18,12 @@ __doc__ = """
Unittests for the `asyncio_taskpool.helpers` module. Unittests for the `asyncio_taskpool.helpers` module.
""" """
import importlib
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase, TestCase
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
from asyncio_taskpool import helpers from asyncio_taskpool.internals import constants
from asyncio_taskpool.internals import helpers
class HelpersTestCase(IsolatedAsyncioTestCase): class HelpersTestCase(IsolatedAsyncioTestCase):
@ -81,12 +82,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
helpers.star_function(f, a, 123456789) helpers.star_function(f, a, 123456789)
async def test_join_queue(self):
mock_join = AsyncMock()
mock_queue = MagicMock(join=mock_join)
self.assertIsNone(await helpers.join_queue(mock_queue))
mock_join.assert_awaited_once_with()
def test_get_first_doc_line(self): def test_get_first_doc_line(self):
expected_output = 'foo bar baz' expected_output = 'foo bar baz'
mock_obj = MagicMock(__doc__=f"""{expected_output} mock_obj = MagicMock(__doc__=f"""{expected_output}
@ -128,3 +123,45 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
helpers.resolve_dotted_path('foo.bar.baz') helpers.resolve_dotted_path('foo.bar.baz')
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')]) mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
class ClassMethodWorkaroundTestCase(TestCase):
def test_init(self):
def func(): return 'foo'
def getter(): return 'bar'
prop = property(getter)
instance = helpers.ClassMethodWorkaround(func)
self.assertIs(func, instance._getter)
instance = helpers.ClassMethodWorkaround(prop)
self.assertIs(getter, instance._getter)
@patch.object(helpers.ClassMethodWorkaround, '__init__', return_value=None)
def test_get(self, _mock_init: MagicMock):
def func(x: MagicMock): return x.__name__
instance = helpers.ClassMethodWorkaround(MagicMock())
instance._getter = func
obj, cls = None, MagicMock
expected_output = 'MagicMock'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
obj = MagicMock(__name__='bar')
expected_output = 'bar'
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
cls = None
output = instance.__get__(obj, cls)
self.assertEqual(expected_output, output)
def test_correct_class(self):
is_older_python = constants.PYTHON_BEFORE_39
try:
constants.PYTHON_BEFORE_39 = True
importlib.reload(helpers)
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
constants.PYTHON_BEFORE_39 = False
importlib.reload(helpers)
self.assertIs(classmethod, helpers.classmethod)
finally:
constants.PYTHON_BEFORE_39 = is_older_python

View File

@ -19,15 +19,12 @@ Unittests for the `asyncio_taskpool.pool` module.
""" """
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.locks import Semaphore from asyncio.locks import Event, Semaphore
from asyncio.queues import QueueEmpty
from datetime import datetime
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
from typing import Type from typing import Type
from asyncio_taskpool import pool, exceptions from asyncio_taskpool import pool, exceptions
from asyncio_taskpool.constants import DATETIME_FORMAT
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set() EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
@ -86,17 +83,20 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.assertEqual(0, self.task_pool._num_started) self.assertEqual(0, self.task_pool._num_started)
self.assertFalse(self.task_pool._locked) self.assertFalse(self.task_pool._locked)
self.assertFalse(self.task_pool._closed) self.assertIsInstance(self.task_pool._closed, Event)
self.assertFalse(self.task_pool._closed.is_set())
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name) self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended) self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
self.assertIsInstance(self.task_pool._enough_room, Semaphore) self.assertIsInstance(self.task_pool._enough_room, Semaphore)
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups) self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
self.assertEqual(self.mock_idx, self.task_pool._idx) self.assertEqual(self.mock_idx, self.task_pool._idx)
self.mock__add_pool.assert_called_once_with(self.task_pool) self.mock__add_pool.assert_called_once_with(self.task_pool)
@ -163,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool.get_group_ids(group_name, 'something else') self.task_pool.get_group_ids(group_name, 'something else')
async def test__check_start(self): async def test__check_start(self):
self.task_pool._closed = True self.task_pool._closed.set()
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock() mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
try: try:
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@ -176,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._check_start(awaitable=None, function=mock_coroutine) self.task_pool._check_start(awaitable=None, function=mock_coroutine)
with self.assertRaises(exceptions.PoolIsClosed): with self.assertRaises(exceptions.PoolIsClosed):
self.task_pool._check_start(awaitable=mock_coroutine, function=None) self.task_pool._check_start(awaitable=mock_coroutine, function=None)
self.task_pool._closed = False self.task_pool._closed.clear()
self.task_pool._locked = True self.task_pool._locked = True
with self.assertRaises(exceptions.PoolIsLocked): with self.assertRaises(exceptions.PoolIsLocked):
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False) self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
@ -312,104 +312,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
self.task_pool._get_running_task(task_id) self.task_pool._get_running_task(task_id)
mock__task_name.assert_not_called() mock__task_name.assert_not_called()
@patch('warnings.warn')
def test__get_cancel_kw(self, mock_warn: MagicMock):
msg = None
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
msg = 'something'
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_called_once()
mock_warn.reset_mock()
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
mock_warn.assert_not_called()
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_get_running_task') @patch.object(pool.BaseTaskPool, '_get_running_task')
def test_cancel(self, mock__get_running_task: MagicMock): def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
task_id1, task_id2, task_id3 = 1, 4, 9 task_id1, task_id2, task_id3 = 1, 4, 9
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock() mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO)) self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)]) mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)]) mock__get_cancel_kw.assert_called_once_with(FOO)
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
def test__cancel_and_remove_all_from_group(self):
task_id = 555
mock_cancel = MagicMock()
self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
class MockRegister(set, MagicMock):
pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
mock_cancel.assert_called_once_with(msg=FOO)
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
self.task_pool._task_groups[FOO] = mock_group_reg
with self.assertRaises(exceptions.InvalidGroupName):
await self.task_pool.cancel_group(BAR)
mock__cancel_and_remove_all_from_group.assert_not_called()
mock_grp_aenter.assert_not_called()
mock_grp_aexit.assert_not_called()
self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR))
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
mock_grp_aenter.assert_awaited_once_with()
mock_grp_aexit.assert_awaited_once()
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
self.task_pool._task_groups[BAR] = mock_group_reg
self.assertIsNone(await self.task_pool.cancel_all(FOO))
mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO)
mock_grp_aenter.assert_awaited_once_with()
mock_grp_aexit.assert_awaited_once()
async def test_flush(self):
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._tasks_ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
async def test_gather_and_close(self):
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
self.task_pool._tasks_running = running = {789: mock_running_func()}
with self.assertRaises(exceptions.PoolStillUnlocked):
await self.task_pool.gather_and_close()
self.assertDictEqual(ended, self.task_pool._tasks_ended)
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
self.assertDictEqual(running, self.task_pool._tasks_running)
self.assertListEqual(before_gather, self.task_pool._before_gathering)
self.assertFalse(self.task_pool._closed)
self.task_pool._locked = True
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_before_gather.assert_awaited_once_with()
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
mock_running_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
self.assertTrue(self.task_pool._closed)
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
self.base_class_init = self.base_class_init_patcher.start()
super().setUp()
def tearDown(self) -> None:
self.base_class_init_patcher.stop()
super().tearDown()
def test_init(self):
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
def test__cancel_group_meta_tasks(self): def test__cancel_group_meta_tasks(self):
mock_task1, mock_task2 = MagicMock(), MagicMock() mock_task1, mock_task2 = MagicMock(), MagicMock()
@ -426,26 +354,48 @@ class TaskPoolTestCase(CommonTestCase):
mock_task1.cancel.assert_called_once_with() mock_task1.cancel.assert_called_once_with()
mock_task2.cancel.assert_called_once_with() mock_task2.cancel.assert_called_once_with()
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
kw = {BAR: 10, BAZ: 20}
task_id = 555
mock_cancel = MagicMock()
def add_mock_task_to_running(_):
self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
# We add the fake task to the `_tasks_running` dictionary as a side effect of calling the mocked method,
# to verify that it is called first, before the cancellation loop starts.
mock__cancel_group_meta_tasks.side_effect = add_mock_task_to_running
class MockRegister(set, MagicMock):
pass
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
mock_cancel.assert_called_once_with(**kw)
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group') @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
@patch.object(pool.TaskPool, '_cancel_group_meta_tasks') def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock, mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_base__cancel_and_remove_all_from_group: MagicMock): self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
group_name, group_reg, msg = 'xyz', MagicMock(), FOO with self.assertRaises(exceptions.InvalidGroupName):
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)) self.task_pool.cancel_group(BAR)
mock__cancel_group_meta_tasks.assert_called_once_with(group_name) mock__cancel_and_remove_all_from_group.assert_not_called()
mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg) self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
mock__get_cancel_kw.assert_called_once_with(BAR)
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
@patch.object(pool.BaseTaskPool, 'cancel_group') @patch.object(pool.BaseTaskPool, '_get_cancel_kw')
async def test_cancel_group(self, mock_base_cancel_group: AsyncMock): @patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
group_name, msg = 'abc', 'xyz' def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
await self.task_pool.cancel_group(group_name, msg=msg) mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg) mock_group_reg = MagicMock()
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
@patch.object(pool.BaseTaskPool, 'cancel_all') self.assertIsNone(self.task_pool.cancel_all(BAZ))
async def test_cancel_all(self, mock_base_cancel_all: AsyncMock): mock__get_cancel_kw.assert_called_once_with(BAZ)
msg = 'xyz' mock__cancel_and_remove_all_from_group.assert_has_calls([
await self.task_pool.cancel_all(msg=msg) call(BAR, mock_group_reg, **fake_cancel_kw),
mock_base_cancel_all.assert_awaited_once_with(msg=msg) call(FOO, mock_group_reg, **fake_cancel_kw)
])
def test__pop_ended_meta_tasks(self): def test__pop_ended_meta_tasks(self):
mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True) mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
@ -457,131 +407,162 @@ class TaskPoolTestCase(CommonTestCase):
self.assertSetEqual(expected_output, output) self.assertSetEqual(expected_output, output)
self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running) self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
@patch.object(pool.TaskPool, '_pop_ended_meta_tasks') @patch.object(pool.BaseTaskPool, '_pop_ended_meta_tasks')
@patch.object(pool.BaseTaskPool, 'flush') async def test_flush(self, mock__pop_ended_meta_tasks: MagicMock):
async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock): # Meta tasks:
mock_ended_meta_task = AsyncMock() mock_ended_meta_task = AsyncMock()
mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()} mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
self.assertIsNone(await self.task_pool.flush(return_exceptions=False)) # Actual tasks:
mock_base_flush.assert_awaited_once_with(return_exceptions=False) mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._tasks_ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
# Meta tasks:
mock__pop_ended_meta_tasks.assert_called_once_with() mock__pop_ended_meta_tasks.assert_called_once_with()
mock_ended_meta_task.assert_awaited_once_with() mock_ended_meta_task.assert_awaited_once_with()
mock_cancelled_meta_task.assert_awaited_once_with() mock_cancelled_meta_task.assert_awaited_once_with()
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
# Actual tasks:
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
@patch.object(pool.BaseTaskPool, 'gather_and_close') @patch.object(pool.BaseTaskPool, 'lock')
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock): async def test_gather_and_close(self, mock_lock: MagicMock):
# Meta tasks:
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock() mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}} self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError) mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()} self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
# Actual tasks:
mock_running_func = AsyncMock()
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
self.task_pool._tasks_ended = {123: mock_ended_func()}
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
self.task_pool._tasks_running = {789: mock_running_func()}
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True)) self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
mock_lock.assert_called_once_with()
# Meta tasks:
mock_meta_task1.assert_awaited_once_with() mock_meta_task1.assert_awaited_once_with()
mock_meta_task2.assert_awaited_once_with() mock_meta_task2.assert_awaited_once_with()
mock_cancelled_meta_task.assert_awaited_once_with() mock_cancelled_meta_task.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running) self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled) self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
# Actual tasks:
mock_ended_func.assert_awaited_once_with()
mock_cancelled_func.assert_awaited_once_with()
mock_running_func.assert_awaited_once_with()
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
self.assertTrue(self.task_pool._closed.is_set())
@patch.object(pool, 'datetime') async def test_until_closed(self):
def test__generate_group_name(self, mock_datetime: MagicMock): self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
output = await self.task_pool.until_closed()
self.assertEqual(FOO, output)
self.task_pool._closed.wait.assert_awaited_once_with()
class TaskPoolTestCase(CommonTestCase):
TEST_CLASS = pool.TaskPool
task_pool: pool.TaskPool
def test__generate_group_name(self):
prefix, func = 'x y z', AsyncMock(__name__=BAR) prefix, func = 'x y z', AsyncMock(__name__=BAR)
dt = datetime(1776, 7, 4, 0, 0, 1) base_name = f'{prefix}-{BAR}-group'
mock_datetime.now = MagicMock(return_value=dt) self.task_pool._task_groups = {
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}' f'{base_name}-0': MagicMock(),
output = pool.TaskPool._generate_group_name(prefix, func) f'{base_name}-1': MagicMock(),
f'{base_name}-100': MagicMock(),
}
expected_output = f'{base_name}-2'
output = self.task_pool._generate_group_name(prefix, func)
self.assertEqual(expected_output, output) self.assertEqual(expected_output, output)
@patch.object(pool.TaskPool, '_start_task') @patch.object(pool.TaskPool, '_start_task')
async def test__apply_num(self, mock__start_task: AsyncMock): async def test__apply_spawner(self, mock__start_task: AsyncMock):
group_name = FOO + BAR grp_name = FOO + BAR
mock_awaitable = object() mock_awaitable1, mock_awaitable2 = object(), object()
mock_func = MagicMock(return_value=mock_awaitable) mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3 args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
mock_func.assert_has_calls(3 * [call(*args, **kwargs)]) mock_func.assert_has_calls(num * [call(*args, **kw)])
mock__start_task.assert_has_awaits(3 * [ mock__start_task.assert_has_awaits([
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
]) ])
mock_func.reset_mock() mock_func.reset_mock(side_effect=True)
mock__start_task.reset_mock() mock__start_task.reset_mock()
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb)) # Simulate cancellation while the second task is being started.
mock_func.assert_has_calls(num * [call(*args)]) mock__start_task.side_effect = [None, CancelledError, None]
mock__start_task.assert_has_awaits(num * [ mock_coroutine_to_close = MagicMock()
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb) mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
mock_func.assert_has_calls(2 * [call(*args)])
mock__start_task.assert_has_awaits([
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
]) ])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock()) @patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister') @patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.TaskPool, '_generate_group_name')
@patch.object(pool.BaseTaskPool, '_check_start') @patch.object(pool.BaseTaskPool, '_check_start')
async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock, def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock): mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 123' mock__generate_group_name.return_value = generated_name = 'name 123'
mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__apply_num.return_value = mock_apply_coroutine = object() mock__apply_spawner.return_value = mock_apply_coroutine = object()
mock_task_future = AsyncMock() mock_create_task.return_value = fake_task = object()
mock_create_task.return_value = mock_task_future()
mock_func, num, group_name = MagicMock(), 3, FOO + BAR mock_func, num, group_name = MagicMock(), 3, FOO + BAR
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2} args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.task_pool._task_groups = {group_name: 'causes error'}
with self.assertRaises(exceptions.InvalidGroupName):
self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=mock_func)
mock__apply_spawner.assert_not_called()
mock_create_task.assert_not_called()
mock__check_start.reset_mock()
self.task_pool._task_groups = {} self.task_pool._task_groups = {}
def check_assertions(_group_name, _output): def check_assertions(_group_name, _output):
self.assertEqual(_group_name, _output) self.assertEqual(_group_name, _output)
mock__check_start.assert_called_once_with(function=mock_func) mock__check_start.assert_called_once_with(function=mock_func)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name]) self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
mock_group_reg.__aenter__.assert_awaited_once_with() mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock_create_task.assert_called_once_with(mock_apply_coroutine) mock_create_task.assert_called_once_with(mock_apply_coroutine)
mock_group_reg.__aexit__.assert_awaited_once() self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
mock_task_future.assert_awaited_once_with()
output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb) output = self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
check_assertions(group_name, output) check_assertions(group_name, output)
mock__generate_group_name.assert_not_called() mock__generate_group_name.assert_not_called()
mock__check_start.reset_mock() mock__check_start.reset_mock()
self.task_pool._task_groups.clear() self.task_pool._task_groups.clear()
mock_group_reg.__aenter__.reset_mock() mock__apply_spawner.reset_mock()
mock__apply_num.reset_mock()
mock_create_task.reset_mock() mock_create_task.reset_mock()
mock_group_reg.__aexit__.reset_mock()
mock_task_future = AsyncMock()
mock_create_task.return_value = mock_task_future()
output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb) output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
check_assertions(generated_name, output) check_assertions(generated_name, output)
mock__generate_group_name.assert_called_once_with('apply', mock_func) mock__generate_group_name.assert_called_once_with('apply', mock_func)
@patch.object(pool, 'Queue')
async def test__queue_producer(self, mock_queue_cls: MagicMock):
mock_put = AsyncMock()
mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
item1, item2, item3 = FOO, 420, 69
arg_iter = iter([item1, item2, item3])
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
with self.assertRaises(StopIteration):
next(arg_iter)
mock_put.reset_mock()
mock_put.side_effect = [CancelledError, None]
arg_iter = iter([item1, item2, item3])
mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
mock_queue.item_processed.assert_has_calls([call(), call()])
self.assertListEqual([item2, item3], list(arg_iter))
@patch.object(pool, 'execute_optional') @patch.object(pool, 'execute_optional')
async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock): async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
semaphore, mock_end_cb = Semaphore(1), MagicMock() semaphore, mock_end_cb = Semaphore(1), MagicMock()
@ -597,144 +578,176 @@ class TaskPoolTestCase(CommonTestCase):
@patch.object(pool, 'Semaphore') @patch.object(pool, 'Semaphore')
async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock, async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
mock__start_task: AsyncMock, mock_star_function: MagicMock): mock__start_task: AsyncMock, mock_star_function: MagicMock):
mock_semaphore_cls.return_value = semaphore = Semaphore(3) n = 2
mock_semaphore_cls.return_value = semaphore = Semaphore(n)
mock__get_map_end_callback.return_value = map_cb = MagicMock() mock__get_map_end_callback.return_value = map_cb = MagicMock()
awaitable = 'totally an awaitable' awaitable1, awaitable2 = 'totally an awaitable', object()
mock_star_function.side_effect = [awaitable, awaitable, Exception()] mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
arg1, arg2, bad = 123456789, 'function argument', None arg1, arg2, bad = 123456789, 'function argument', None
mock_q_maxsize = 3 args = [arg1, bad, arg2]
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]), grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb)) self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
# We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then # We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked.
# acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3,
# at the end of the loop, we expect it be locked.
self.assertTrue(semaphore.locked()) self.assertTrue(semaphore.locked())
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize) mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb) mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
mock__start_task.assert_has_awaits(2 * [ mock__start_task.assert_has_awaits([
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb) call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
]) ])
mock_star_function.assert_has_calls([ mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars), call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars), call(mock_func, bad, arg_stars=stars),
call(mock_func, bad, arg_stars=stars) call(mock_func, arg2, arg_stars=stars)
]) ])
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while acquiring the semaphore:
mock_acquire = AsyncMock(side_effect=[True, CancelledError])
mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
mock_star_function.return_value = mock_coroutine = MagicMock()
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
end_callback=map_cb, cancel_callback=cancel_cb)
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_not_called()
self.assertEqual(FOO, next(arg_it))
mock_acquire.reset_mock(side_effect=True)
mock_semaphore_cls.reset_mock()
mock__get_map_end_callback.reset_mock()
mock__start_task.reset_mock()
mock_star_function.reset_mock(side_effect=True)
# With a CancelledError thrown while starting the task:
mock__start_task.side_effect = [None, CancelledError]
arg_it = iter(arg for arg in (arg1, arg2, FOO))
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
mock_semaphore_cls.assert_called_once_with(n)
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
mock_star_function.assert_has_calls([
call(mock_func, arg1, arg_stars=stars),
call(mock_func, arg2, arg_stars=stars)
])
mock_acquire.assert_has_awaits([call(), call()])
mock__start_task.assert_has_awaits(2 * [
call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
])
mock_coroutine.close.assert_called_once_with()
mock_semaphore.release.assert_called_once_with()
self.assertEqual(FOO, next(arg_it))
@patch.object(pool, 'create_task') @patch.object(pool, 'create_task')
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock) @patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
@patch.object(pool, 'join_queue', new_callable=MagicMock)
@patch.object(pool, 'Queue')
@patch.object(pool, 'TaskGroupRegister') @patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start') @patch.object(pool.BaseTaskPool, '_check_start')
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock, def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock,
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock, mock_create_task: MagicMock):
mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls) mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock_queue_cls.return_value = mock_q = MagicMock() mock__arg_consumer.return_value = fake_consumer = object()
mock_join_queue.return_value = fake_join = object() mock_create_task.return_value = fake_task = object()
mock__queue_producer.return_value = fake_producer = object()
mock__queue_consumer.return_value = fake_consumer = object()
fake_task1, fake_task2 = object(), object()
mock_create_task.side_effect = [fake_task1, fake_task2]
group_name, group_size = 'onetwothree', 0 group_name, n = 'onetwothree', 0
func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3 func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=func) mock__check_start.assert_called_once_with(function=func)
mock__check_start.reset_mock() mock__check_start.reset_mock()
group_size = 1234 n = 1234
self.task_pool._task_groups = {group_name: MagicMock()} self.task_pool._task_groups = {group_name: MagicMock()}
with self.assertRaises(exceptions.InvalidGroupName): with self.assertRaises(exceptions.InvalidGroupName):
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb) self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
mock__check_start.assert_called_once_with(function=func) mock__check_start.assert_called_once_with(function=func)
mock__check_start.reset_mock() mock__check_start.reset_mock()
self.task_pool._task_groups.clear() self.task_pool._task_groups.clear()
self.task_pool._before_gathering = []
self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)) self.assertIsNone(self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb))
mock__check_start.assert_called_once_with(function=func) mock__check_start.assert_called_once_with(function=func)
mock_reg_cls.assert_called_once_with() mock_reg_cls.assert_called_once_with()
self.task_pool._task_groups[group_name] = mock_group_reg self.task_pool._task_groups[group_name] = mock_group_reg
mock_group_reg.__aenter__.assert_awaited_once_with() mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars,
mock_queue_cls.assert_called_once_with(maxsize=group_size) end_callback=end_cb, cancel_callback=cancel_cb)
mock_join_queue.assert_called_once_with(mock_q) mock_create_task.assert_called_once_with(fake_consumer)
self.assertListEqual([fake_join], self.task_pool._before_gathering) self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
mock__queue_producer.assert_called_once()
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
mock_group_reg.__aexit__.assert_awaited_once()
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
@patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.TaskPool, '_generate_group_name')
async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): def test_map(self, mock__generate_group_name: MagicMock, mock__map: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock() mock_func = MagicMock()
arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR arg_iter, num_concurrent, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb) output = self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb)
self.assertEqual(group_name, output) self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0, mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called() mock__generate_group_name.assert_not_called()
mock__map.reset_mock() mock__map.reset_mock()
output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb) output = self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output) self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0, mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('map', mock_func) mock__generate_group_name.assert_called_once_with('map', mock_func)
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
@patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.TaskPool, '_generate_group_name')
async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock() mock_func = MagicMock()
args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR args_iter, num_concurrent, group_name = ([FOO], [BAR]), 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb) output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb)
self.assertEqual(group_name, output) self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1, mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called() mock__generate_group_name.assert_not_called()
mock__map.reset_mock() mock__map.reset_mock()
output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb) output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output) self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1, mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, args_iter, 1,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('starmap', mock_func) mock__generate_group_name.assert_called_once_with('starmap', mock_func)
@patch.object(pool.TaskPool, '_map') @patch.object(pool.TaskPool, '_map')
@patch.object(pool.TaskPool, '_generate_group_name') @patch.object(pool.TaskPool, '_generate_group_name')
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock): async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock):
mock__generate_group_name.return_value = generated_name = 'name 1 2 3' mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
mock_func = MagicMock() mock_func = MagicMock()
kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR kw_iter, num_concurrent, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
end_cb, cancel_cb = MagicMock(), MagicMock() end_cb, cancel_cb = MagicMock(), MagicMock()
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb) output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb)
self.assertEqual(group_name, output) self.assertEqual(group_name, output)
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2, mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, kw_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_not_called() mock__generate_group_name.assert_not_called()
mock__map.reset_mock() mock__map.reset_mock()
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb) output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb)
self.assertEqual(generated_name, output) self.assertEqual(generated_name, output)
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2, mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2,
end_callback=end_cb, cancel_callback=cancel_cb) end_callback=end_cb, cancel_callback=cancel_cb)
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func) mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
@ -749,13 +762,15 @@ class SimpleTaskPoolTestCase(CommonTestCase):
TEST_POOL_CANCEL_CB = MagicMock() TEST_POOL_CANCEL_CB = MagicMock()
def get_task_pool_init_params(self) -> dict: def get_task_pool_init_params(self) -> dict:
return super().get_task_pool_init_params() | { params = super().get_task_pool_init_params()
params.update({
'func': self.TEST_POOL_FUNC, 'func': self.TEST_POOL_FUNC,
'args': self.TEST_POOL_ARGS, 'args': self.TEST_POOL_ARGS,
'kwargs': self.TEST_POOL_KWARGS, 'kwargs': self.TEST_POOL_KWARGS,
'end_callback': self.TEST_POOL_END_CB, 'end_callback': self.TEST_POOL_END_CB,
'cancel_callback': self.TEST_POOL_CANCEL_CB, 'cancel_callback': self.TEST_POOL_CANCEL_CB,
} })
return params
def setUp(self) -> None: def setUp(self) -> None:
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__') self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
@ -764,6 +779,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
def tearDown(self) -> None: def tearDown(self) -> None:
self.base_class_init_patcher.stop() self.base_class_init_patcher.stop()
super().tearDown()
def test_init(self): def test_init(self):
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func) self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)
@ -780,23 +796,54 @@ class SimpleTaskPoolTestCase(CommonTestCase):
self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name) self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name)
@patch.object(pool.SimpleTaskPool, '_start_task') @patch.object(pool.SimpleTaskPool, '_start_task')
async def test__start_one(self, mock__start_task: AsyncMock): async def test__start_num(self, mock__start_task: AsyncMock):
mock__start_task.return_value = expected_output = 99 group_name = FOO + BAR + 'abc'
self.task_pool._func = MagicMock(return_value=BAR) mock_awaitable1, mock_awaitable2 = object(), object()
output = await self.task_pool._start_one() self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
self.assertEqual(expected_output, output) num = 3
self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs) self.assertIsNone(await self.task_pool._start_num(num, group_name))
mock__start_task.assert_awaited_once_with(BAR, end_callback=self.task_pool._end_callback, self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
cancel_callback=self.task_pool._cancel_callback) call_kw = {
'group_name': group_name,
'end_callback': self.task_pool._end_callback,
'cancel_callback': self.task_pool._cancel_callback
}
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
@patch.object(pool.SimpleTaskPool, '_start_one') self.task_pool._func.reset_mock(side_effect=True)
async def test_start(self, mock__start_one: AsyncMock): mock__start_task.reset_mock()
mock__start_one.return_value = FOO
# Simulate cancellation while the second task is being started.
mock__start_task.side_effect = [None, CancelledError, None]
mock_coroutine_to_close = MagicMock()
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
self.assertIsNone(await self.task_pool._start_num(num, group_name))
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
mock_coroutine_to_close.close.assert_called_once_with()
@patch.object(pool, 'create_task')
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
@patch.object(pool, 'TaskGroupRegister')
@patch.object(pool.BaseTaskPool, '_check_start')
def test_start(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__start_num: AsyncMock,
mock_create_task: MagicMock):
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
mock__start_num.return_value = mock_start_num_coroutine = object()
mock_create_task.return_value = fake_task = object()
self.task_pool._task_groups = {}
self.task_pool._group_meta_tasks_running = {}
num = 5 num = 5
output = await self.task_pool.start(num) self.task_pool._start_calls = 42
expected_output = num * [FOO] expected_group_name = 'start-group-42'
self.assertListEqual(expected_output, output) output = self.task_pool.start(num)
mock__start_one.assert_has_awaits(num * [call()]) self.assertEqual(expected_group_name, output)
mock__check_start.assert_called_once_with(function=self.TEST_POOL_FUNC)
self.assertEqual(43, self.task_pool._start_calls)
self.assertEqual(mock_group_reg, self.task_pool._task_groups[expected_group_name])
mock__start_num.assert_called_once_with(num, expected_group_name)
mock_create_task.assert_called_once_with(mock_start_num_coroutine)
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[expected_group_name])
@patch.object(pool.SimpleTaskPool, 'cancel') @patch.object(pool.SimpleTaskPool, 'cancel')
def test_stop(self, mock_cancel: MagicMock): def test_stop(self, mock_cancel: MagicMock):

View File

@ -39,12 +39,11 @@ async def work(n: int) -> None:
async def main() -> None: async def main() -> None:
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
await pool.start(3) # launches work tasks 0, 1, and 2 pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
await pool.start() # launches work task 3 pool.start(1) # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2 (LIFO order) pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
pool.lock() # required for the last line
await pool.gather_and_close() # awaits all tasks, then flushes the pool await pool.gather_and_close() # awaits all tasks, then flushes the pool
@ -123,7 +122,7 @@ async def main() -> None:
pool = TaskPool(3) pool = TaskPool(3)
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments). # Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments).
print("> Called `apply`") print("> Called `apply`")
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2) pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
# Let the tasks work for a bit. # Let the tasks work for a bit.
await asyncio.sleep(1.5) await asyncio.sleep(1.5)
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different # Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
@ -135,11 +134,9 @@ async def main() -> None:
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5) # Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
# only once there is room in the pool and no more than one other task of these new ones is running. # only once there is room in the pool and no more than one other task of these new ones is running.
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)] args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
await pool.starmap(other_work, args_list, group_size=2) pool.starmap(other_work, args_list, num_concurrent=2)
print("> Called `starmap`") print("> Called `starmap`")
# Now we lock the pool, so that we can safely await all our tasks. # We block, until all tasks have ended.
pool.lock()
# Finally, we block, until all tasks have ended.
print("> Calling `gather_and_close`...") print("> Calling `gather_and_close`...")
await pool.gather_and_close() await pool.gather_and_close()
print("> Done.") print("> Done.")
@ -199,7 +196,7 @@ Started TaskPool-0_Task-3
> other_work with 15 > other_work with 15
Ended TaskPool-0_Task-0 Ended TaskPool-0_Task-0
Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start Started TaskPool-0_Task-4 <--- since `num_concurrent` is set to 2, Task-5 will not start
> work with 190 > work with 190
> work with 190 > work with 190
> other_work with 16 > other_work with 16

View File

@ -23,8 +23,9 @@ Use the main CLI client to interface at the socket.
import asyncio import asyncio
import logging import logging
from asyncio_taskpool import SimpleTaskPool, TCPControlServer from asyncio_taskpool import SimpleTaskPool
from asyncio_taskpool.constants import PACKAGE_NAME from asyncio_taskpool.control import TCPControlServer
from asyncio_taskpool.internals.constants import PACKAGE_NAME
logging.getLogger().setLevel(logging.NOTSET) logging.getLogger().setLevel(logging.NOTSET)
@ -66,7 +67,7 @@ async def main() -> None:
for item in range(100): for item in range(100):
q.put_nowait(item) q.put_nowait(item)
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
await pool.start(3) # launches three worker tasks pool.start(3) # launches three worker tasks
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever() control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue. # We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join() await q.join()
@ -74,7 +75,6 @@ async def main() -> None:
control_server_task.cancel() control_server_task.cancel()
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left, # Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
# we can now safely cancel their tasks. # we can now safely cancel their tasks.
pool.lock()
pool.stop_all() pool.stop_all()
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled. # Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions, # We block until they all return or raise an exception, but since we are not interested in any of their exceptions,