generated from daniil-berg/boilerplate-py
Compare commits
16 Commits
Author | SHA1 | Date | |
---|---|---|---|
ee0b8c0002
|
|||
28c997e0ee
|
|||
5a72a6d1d1
|
|||
72e380cd77
|
|||
85672bddeb
|
|||
dae883446a
|
|||
a4ecf39157
|
|||
e3bbb05eac
|
|||
36527ccffc
|
|||
d047b99119
|
|||
d7cd16c540
|
|||
db306a1a1f
|
|||
3a8fcb2d5a
|
|||
cf02206588
|
|||
0796038dcd
|
|||
f4e33baf82
|
@ -1,14 +1,12 @@
|
|||||||
[run]
|
[run]
|
||||||
source = src/
|
source = src/
|
||||||
branch = true
|
branch = true
|
||||||
omit =
|
command_line = -m unittest discover
|
||||||
.venv/*
|
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
|
fail_under = 100
|
||||||
show_missing = True
|
show_missing = True
|
||||||
skip_covered = False
|
skip_covered = False
|
||||||
exclude_lines =
|
exclude_lines =
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
if __name__ == ['"]__main__['"]:
|
if __name__ == ['"]__main__['"]:
|
||||||
omit =
|
|
||||||
tests/*
|
|
||||||
|
88
.github/workflows/main.yaml
vendored
Normal file
88
.github/workflows/main.yaml
vendored
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
name: CI
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [master]
|
||||||
|
jobs:
|
||||||
|
tests:
|
||||||
|
name: Python ${{ matrix.python-version }} Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version:
|
||||||
|
- '3.8'
|
||||||
|
- '3.9'
|
||||||
|
- '3.10'
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements/dev.txt'
|
||||||
|
|
||||||
|
- name: Upgrade packaging tools
|
||||||
|
run: pip install -U pip
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -U -r requirements/dev.txt
|
||||||
|
|
||||||
|
- name: Install asyncio-taskpool
|
||||||
|
run: pip install -e .
|
||||||
|
|
||||||
|
- name: Run tests for Python ${{ matrix.python-version }}
|
||||||
|
if: ${{ matrix.python-version != '3.10' }}
|
||||||
|
run: python -m tests
|
||||||
|
|
||||||
|
- name: Run tests for Python 3.10 and save coverage
|
||||||
|
if: ${{ matrix.python-version == '3.10' }}
|
||||||
|
run: echo "coverage=$(./coverage.sh)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
coverage: ${{ env.coverage }}
|
||||||
|
|
||||||
|
update_badges:
|
||||||
|
needs: tests
|
||||||
|
name: Update Badges
|
||||||
|
env:
|
||||||
|
meta_gist_id: 3f8240a976e8781a765d9c74a583dcda
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Download `cloc`
|
||||||
|
run: sudo apt-get update -y && sudo apt-get install -y cloc
|
||||||
|
|
||||||
|
- name: Count lines of code/comments
|
||||||
|
run: |
|
||||||
|
echo "cloc_code=$(./cloc.sh -c src/)" >> $GITHUB_ENV
|
||||||
|
echo "cloc_comments=$(./cloc.sh -m src/)" >> $GITHUB_ENV
|
||||||
|
echo "cloc_commentpercent=$(./cloc.sh -p src/)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Create badge for lines of code
|
||||||
|
uses: Schneegans/dynamic-badges-action@v1.2.0
|
||||||
|
with:
|
||||||
|
auth: ${{ secrets.GIST_META_DATA }}
|
||||||
|
gistID: ${{ env.meta_gist_id }}
|
||||||
|
filename: cloc-code.json
|
||||||
|
label: Lines of Code
|
||||||
|
message: ${{ env.cloc_code }}
|
||||||
|
|
||||||
|
- name: Create badge for lines of comments
|
||||||
|
uses: Schneegans/dynamic-badges-action@v1.2.0
|
||||||
|
with:
|
||||||
|
auth: ${{ secrets.GIST_META_DATA }}
|
||||||
|
gistID: ${{ env.meta_gist_id }}
|
||||||
|
filename: cloc-comments.json
|
||||||
|
label: Comments
|
||||||
|
message: ${{ env.cloc_comments }} (${{ env.cloc_commentpercent }}%)
|
||||||
|
|
||||||
|
- name: Create badge for test coverage
|
||||||
|
uses: Schneegans/dynamic-badges-action@v1.2.0
|
||||||
|
with:
|
||||||
|
auth: ${{ secrets.GIST_META_DATA }}
|
||||||
|
gistID: ${{ env.meta_gist_id }}
|
||||||
|
filename: test-coverage.json
|
||||||
|
label: Coverage
|
||||||
|
message: ${{ needs.tests.outputs.coverage }}
|
11
.readthedocs.yaml
Normal file
11
.readthedocs.yaml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
version: 2
|
||||||
|
build:
|
||||||
|
os: 'ubuntu-20.04'
|
||||||
|
tools:
|
||||||
|
python: '3.8'
|
||||||
|
python:
|
||||||
|
install:
|
||||||
|
- method: pip
|
||||||
|
path: .
|
||||||
|
sphinx:
|
||||||
|
fail_on_warning: true
|
36
README.md
36
README.md
@ -1,7 +1,30 @@
|
|||||||
|
[//]: # (This file is part of asyncio-taskpool.)
|
||||||
|
|
||||||
|
[//]: # (asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of)
|
||||||
|
[//]: # (version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.)
|
||||||
|
|
||||||
|
[//]: # (asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;)
|
||||||
|
[//]: # (without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.)
|
||||||
|
[//]: # (See the GNU Lesser General Public License for more details.)
|
||||||
|
|
||||||
|
[//]: # (You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.)
|
||||||
|
[//]: # (If not, see <https://www.gnu.org/licenses/>.)
|
||||||
|
|
||||||
# asyncio-taskpool
|
# asyncio-taskpool
|
||||||
|
|
||||||
|
[![GitHub last commit][github-last-commit-img]][github-last-commit]
|
||||||
|
![Lines of code][gist-cloc-code-img]
|
||||||
|
![Lines of comments][gist-cloc-comments-img]
|
||||||
|
![Test coverage][gist-test-coverage-img]
|
||||||
|
[![License: LGPL v3.0][lgpl3-img]][lgpl3]
|
||||||
|
[![PyPI version][pypi-latest-version-img]][pypi-latest-version]
|
||||||
|
|
||||||
**Dynamically manage pools of asyncio tasks**
|
**Dynamically manage pools of asyncio tasks**
|
||||||
|
|
||||||
|
Full documentation available at [RtD](https://asyncio-taskpool.readthedocs.io/en/latest).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Contents
|
## Contents
|
||||||
- [Contents](#contents)
|
- [Contents](#contents)
|
||||||
- [Summary](#summary)
|
- [Summary](#summary)
|
||||||
@ -55,8 +78,7 @@ Python Version 3.8+, tested on Linux
|
|||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.
|
Install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`, then execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and save the coverage report.
|
||||||
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
@ -67,3 +89,13 @@ The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COP
|
|||||||
---
|
---
|
||||||
|
|
||||||
© 2022 Daniil Fajnberg
|
© 2022 Daniil Fajnberg
|
||||||
|
|
||||||
|
[github-last-commit]: https://github.com/daniil-berg/asyncio-taskpool/commits
|
||||||
|
[github-last-commit-img]: https://img.shields.io/github/last-commit/daniil-berg/asyncio-taskpool?label=Last%20commit&logo=git&
|
||||||
|
[gist-cloc-code-img]: https://img.shields.io/endpoint?logo=python&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-code.json
|
||||||
|
[gist-cloc-comments-img]: https://img.shields.io/endpoint?logo=sharp&color=lightgrey&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-comments.json
|
||||||
|
[gist-test-coverage-img]: https://img.shields.io/endpoint?logo=pytest&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/test-coverage.json
|
||||||
|
[lgpl3]: https://www.gnu.org/licenses/lgpl-3.0
|
||||||
|
[lgpl3-img]: https://img.shields.io/badge/License-LGPL_v3.0-darkgreen.svg?logo=gnu
|
||||||
|
[pypi-latest-version-img]: https://img.shields.io/pypi/v/asyncio-taskpool?color=teal&logo=pypi
|
||||||
|
[pypi-latest-version]: https://pypi.org/project/asyncio-taskpool/
|
||||||
|
46
cloc.sh
Executable file
46
cloc.sh
Executable file
@ -0,0 +1,46 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
# See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
# If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
typeset option
|
||||||
|
if getopts 'bcmp' option; then
|
||||||
|
if [[ ${option} == [bcmp] ]]; then
|
||||||
|
shift
|
||||||
|
else
|
||||||
|
echo >&2 "Invalid option '$1' provided"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
typeset source=$1
|
||||||
|
if [[ -z ${source} ]]; then
|
||||||
|
echo >&2 Source file/directory missing
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
typeset blank code comment commentpercent
|
||||||
|
read blank comment code commentpercent < <( \
|
||||||
|
cloc --csv --quiet --hide-rate --include-lang Python ${source} |
|
||||||
|
awk -F, '$2 == "SUM" {printf ("%d %d %d %1.0f", $3, $4, $5, 100 * $4 / ($5 + $4)); exit}'
|
||||||
|
)
|
||||||
|
|
||||||
|
case ${option} in
|
||||||
|
b) echo ${blank} ;;
|
||||||
|
c) echo ${code} ;;
|
||||||
|
m) echo ${comment} ;;
|
||||||
|
p) echo ${commentpercent} ;;
|
||||||
|
*) echo Blank lines: ${blank}
|
||||||
|
echo Lines of comments: ${comment}
|
||||||
|
echo Lines of code: ${code}
|
||||||
|
echo Comment percentage: ${commentpercent} ;;
|
||||||
|
esac
|
26
coverage.sh
26
coverage.sh
@ -1,3 +1,25 @@
|
|||||||
#!/usr/bin/env sh
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
coverage erase && coverage run -m unittest discover && coverage report
|
# This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
# See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
# If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
coverage erase
|
||||||
|
coverage run 2> /dev/null
|
||||||
|
|
||||||
|
typeset report=$(coverage report)
|
||||||
|
typeset total=$(echo "${report}" | awk '$1 == "TOTAL" {print $NF; exit}')
|
||||||
|
|
||||||
|
if [[ ${total} == 100% ]]; then
|
||||||
|
echo ${total}
|
||||||
|
else
|
||||||
|
echo "${report}"
|
||||||
|
fi
|
||||||
|
0
docs/source/_static/.placeholder
Normal file
0
docs/source/_static/.placeholder
Normal file
@ -1,7 +0,0 @@
|
|||||||
asyncio\_taskpool.control.parser module
|
|
||||||
=======================================
|
|
||||||
|
|
||||||
.. automodule:: asyncio_taskpool.control.parser
|
|
||||||
:members:
|
|
||||||
:undoc-members:
|
|
||||||
:show-inheritance:
|
|
@ -13,6 +13,4 @@ Submodules
|
|||||||
:maxdepth: 4
|
:maxdepth: 4
|
||||||
|
|
||||||
asyncio_taskpool.control.client
|
asyncio_taskpool.control.client
|
||||||
asyncio_taskpool.control.parser
|
|
||||||
asyncio_taskpool.control.server
|
asyncio_taskpool.control.server
|
||||||
asyncio_taskpool.control.session
|
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
asyncio\_taskpool.control.session module
|
|
||||||
========================================
|
|
||||||
|
|
||||||
.. automodule:: asyncio_taskpool.control.session
|
|
||||||
:members:
|
|
||||||
:undoc-members:
|
|
||||||
:show-inheritance:
|
|
@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
|
|||||||
author = 'Daniil Fajnberg'
|
author = 'Daniil Fajnberg'
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
release = '1.0.0-beta'
|
release = '1.1.3'
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
@ -45,6 +45,7 @@ Contents
|
|||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
pages/pool
|
pages/pool
|
||||||
|
pages/ids
|
||||||
pages/control
|
pages/control
|
||||||
api/api
|
api/api
|
||||||
|
|
||||||
|
42
docs/source/pages/ids.rst
Normal file
42
docs/source/pages/ids.rst
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
.. This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
.. Copyright © 2022 Daniil Fajnberg
|
||||||
|
|
||||||
|
|
||||||
|
IDs, groups & names
|
||||||
|
===================
|
||||||
|
|
||||||
|
Task IDs
|
||||||
|
--------
|
||||||
|
|
||||||
|
Every task spawned within a pool receives an ID, which is an integer greater or equal to 0 that is unique **within that task pool instance**. An internal counter is incremented whenever a new task is spawned. A task with ID :code:`n` was the :code:`(n+1)`-th task to be spawned in the pool. Task IDs can be used to cancel specific tasks using the :py:meth:`.cancel() <asyncio_taskpool.pool.BaseTaskPool.cancel>` method.
|
||||||
|
|
||||||
|
In practice, it should rarely be necessary to target *specific* tasks. When dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, you would typically cancel entire task groups (see below) rather than individual tasks, whereas with :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instances you would indiscriminately cancel a number of tasks using the :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` method.
|
||||||
|
|
||||||
|
The ID of a pool task also appears in the task's name, which is set upon spawning it. (See `here <https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.set_name>`_ for the associated method of the :code:`Task` class.)
|
||||||
|
|
||||||
|
Task groups
|
||||||
|
-----------
|
||||||
|
|
||||||
|
Every method of spawning new tasks in a task pool will add them to a **task group** and return the name of that group. With :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` methods such as :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>`, the group name can be set explicitly via the :code:`group_name` parameter. By default, the name will be a string containing some meta information depending on which method is used. Passing an existing task group name in any of those methods will result in a :py:class:`InvalidGroupName <asyncio_taskpool.exceptions.InvalidGroupName>` error.
|
||||||
|
|
||||||
|
You can cancel entire task groups using the :py:meth:`.cancel_group() <asyncio_taskpool.pool.BaseTaskPool.cancel_group>` method by passing it the group name. To check which tasks belong to a group, the :py:meth:`.get_group_ids() <asyncio_taskpool.pool.BaseTaskPool.get_group_ids>` method can be used, which takes group names and returns the IDs of the tasks belonging to them.
|
||||||
|
|
||||||
|
The :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method will create a new group as well, each time it is called, but it does not allow customizing the group name. Typically, it will not be necessary to keep track of groups in a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instance.
|
||||||
|
|
||||||
|
Task groups do not impose limits on the number of tasks in them, although they can be indirectly constrained by pool size limits.
|
||||||
|
|
||||||
|
Pool names
|
||||||
|
----------
|
||||||
|
|
||||||
|
When initializing a task pool, you can provide a custom name for it, which will appear in its string representation, e.g. when using it in a :code:`print()`. A class attribute keeps track of initialized task pools and assigns each one an index (similar to IDs for pool tasks). If no name is specified when creating a new pool, its index is used in the string representation of it. Pool names can be helpful when using multiple pools and analyzing log messages.
|
@ -87,7 +87,7 @@ By contrast, here is how you would do it with a task pool:
|
|||||||
...
|
...
|
||||||
await pool.flush()
|
await pool.flush()
|
||||||
|
|
||||||
Pretty much self-explanatory, no?
|
Pretty much self-explanatory, no? (See :doc:`here <./ids>` for more information about groups/names).
|
||||||
|
|
||||||
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
|
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 1.0.0
|
version = 1.1.3
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -35,7 +35,7 @@ __all__ = []
|
|||||||
|
|
||||||
CLIENT_CLASS = 'client_class'
|
CLIENT_CLASS = 'client_class'
|
||||||
UNIX, TCP = 'unix', 'tcp'
|
UNIX, TCP = 'unix', 'tcp'
|
||||||
SOCKET_PATH = 'path'
|
SOCKET_PATH = 'socket_path'
|
||||||
HOST, PORT = 'host', 'port'
|
HOST, PORT = 'host', 'port'
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,6 +85,7 @@ class ControlClient(ABC):
|
|||||||
"""
|
"""
|
||||||
self._connected = True
|
self._connected = True
|
||||||
writer.write(json.dumps(self._client_info()).encode())
|
writer.write(json.dumps(self._client_info()).encode())
|
||||||
|
writer.write(b'\n')
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
||||||
@ -97,21 +98,22 @@ class ControlClient(ABC):
|
|||||||
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
|
`None`, if either `Ctrl+C` was hit, an empty or whitespace-only string was entered, or the user wants the
|
||||||
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
|
client to disconnect; otherwise, returns the user's input, stripped of leading and trailing spaces and
|
||||||
|
converted to lowercase.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
msg = input("> ").strip().lower()
|
cmd = input("> ").strip().lower()
|
||||||
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
|
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
|
||||||
msg = CLIENT_EXIT
|
cmd = CLIENT_EXIT
|
||||||
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
||||||
print()
|
print()
|
||||||
return
|
return
|
||||||
if msg == CLIENT_EXIT:
|
if cmd == CLIENT_EXIT:
|
||||||
writer.close()
|
writer.close()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
return
|
return
|
||||||
return msg
|
return cmd or None # will be None if `cmd` is an empty string
|
||||||
|
|
||||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
"""
|
"""
|
||||||
@ -130,6 +132,7 @@ class ControlClient(ABC):
|
|||||||
try:
|
try:
|
||||||
# Send the command to the server.
|
# Send the command to the server.
|
||||||
writer.write(cmd.encode())
|
writer.write(cmd.encode())
|
||||||
|
writer.write(b'\n')
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
@ -17,19 +17,21 @@ If not, see <https://www.gnu.org/licenses/>."""
|
|||||||
__doc__ = """
|
__doc__ = """
|
||||||
Definition of the :class:`ControlParser` used in a
|
Definition of the :class:`ControlParser` used in a
|
||||||
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
|
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
|
||||||
|
|
||||||
|
It should not be considered part of the public API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
|
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
from asyncio.streams import StreamWriter
|
|
||||||
from inspect import Parameter, getmembers, isfunction, signature
|
from inspect import Parameter, getmembers, isfunction, signature
|
||||||
|
from io import StringIO
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
|
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
|
||||||
|
|
||||||
from ..exceptions import HelpRequested, ParserError
|
from ..exceptions import HelpRequested, ParserError
|
||||||
from ..internals.constants import CLIENT_INFO, CMD, STREAM_WRITER
|
from ..internals.constants import CLIENT_INFO, CMD
|
||||||
from ..internals.helpers import get_first_doc_line, resolve_dotted_path
|
from ..internals.helpers import get_first_doc_line, resolve_dotted_path
|
||||||
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
||||||
|
|
||||||
@ -52,8 +54,8 @@ class ControlParser(ArgumentParser):
|
|||||||
"""
|
"""
|
||||||
Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
|
Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
|
||||||
|
|
||||||
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
|
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a file-like
|
||||||
instance passed to it during initialization.
|
`StringIO` instance passed to it during initialization.
|
||||||
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
|
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
|
||||||
connected client.
|
connected client.
|
||||||
Finally, it offers some convenience methods and makes use of custom exceptions.
|
Finally, it offers some convenience methods and makes use of custom exceptions.
|
||||||
@ -87,25 +89,23 @@ class ControlParser(ArgumentParser):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
return ClientHelpFormatter
|
return ClientHelpFormatter
|
||||||
|
|
||||||
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None:
|
def __init__(self, stream: StringIO, terminal_width: int = None, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Sets some internal attributes in addition to the base class.
|
Sets some internal attributes in addition to the base class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_writer:
|
stream:
|
||||||
The instance of the :class:`asyncio.StreamWriter` to use for message output.
|
A file-like I/O object to use for message output.
|
||||||
terminal_width (optional):
|
terminal_width (optional):
|
||||||
The terminal width to use for all message formatting. By default the :code:`columns` attribute from
|
The terminal width to use for all message formatting. By default the :code:`columns` attribute from
|
||||||
:func:`shutil.get_terminal_size` is taken.
|
:func:`shutil.get_terminal_size` is taken.
|
||||||
**kwargs(optional):
|
**kwargs(optional):
|
||||||
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
|
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
|
||||||
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
|
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
|
||||||
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
|
|
||||||
"""
|
"""
|
||||||
self._stream_writer: StreamWriter = stream_writer
|
self._stream: StringIO = stream
|
||||||
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
|
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
|
||||||
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
|
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
|
||||||
kwargs.setdefault('exit_on_error', False)
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._flags: Set[str] = set()
|
self._flags: Set[str] = set()
|
||||||
self._commands = None
|
self._commands = None
|
||||||
@ -194,7 +194,7 @@ class ControlParser(ArgumentParser):
|
|||||||
Dictionary mapping class member names to the (sub-)parsers created from them.
|
Dictionary mapping class member names to the (sub-)parsers created from them.
|
||||||
"""
|
"""
|
||||||
parsers: ParsersDict = {}
|
parsers: ParsersDict = {}
|
||||||
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
|
common_kwargs = {'stream': self._stream, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
|
||||||
for name, member in getmembers(cls):
|
for name, member in getmembers(cls):
|
||||||
if name in omit_members or (name.startswith('_') and public_only):
|
if name in omit_members or (name.startswith('_') and public_only):
|
||||||
continue
|
continue
|
||||||
@ -214,9 +214,9 @@ class ControlParser(ArgumentParser):
|
|||||||
return self._commands
|
return self._commands
|
||||||
|
|
||||||
def _print_message(self, message: str, *args, **kwargs) -> None:
|
def _print_message(self, message: str, *args, **kwargs) -> None:
|
||||||
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
|
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream buffer."""
|
||||||
if message:
|
if message:
|
||||||
self._stream_writer.write(message.encode())
|
self._stream.write(message)
|
||||||
|
|
||||||
def exit(self, status: int = 0, message: str = None) -> None:
|
def exit(self, status: int = 0, message: str = None) -> None:
|
||||||
"""This is overridden to prevent system exit to be invoked."""
|
"""This is overridden to prevent system exit to be invoked."""
|
||||||
|
@ -31,6 +31,7 @@ from typing import Optional, Union
|
|||||||
from .client import ControlClient, TCPControlClient, UnixControlClient
|
from .client import ControlClient, TCPControlClient, UnixControlClient
|
||||||
from .session import ControlSession
|
from .session import ControlSession
|
||||||
from ..pool import AnyTaskPoolT
|
from ..pool import AnyTaskPoolT
|
||||||
|
from ..internals.helpers import classmethod
|
||||||
from ..internals.types import ConnectedCallbackT, PathT
|
from ..internals.types import ConnectedCallbackT, PathT
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +16,8 @@ If not, see <https://www.gnu.org/licenses/>."""
|
|||||||
|
|
||||||
__doc__ = """
|
__doc__ = """
|
||||||
Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
|
Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
|
||||||
|
|
||||||
|
It should not be considered part of the public API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -24,12 +26,13 @@ import json
|
|||||||
from argparse import ArgumentError
|
from argparse import ArgumentError
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
from inspect import isfunction, signature
|
from inspect import isfunction, signature
|
||||||
|
from io import StringIO
|
||||||
from typing import Callable, Optional, Union, TYPE_CHECKING
|
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from .parser import ControlParser
|
from .parser import ControlParser
|
||||||
from ..exceptions import CommandError, HelpRequested, ParserError
|
from ..exceptions import CommandError, HelpRequested, ParserError
|
||||||
from ..pool import TaskPool, SimpleTaskPool
|
from ..pool import TaskPool, SimpleTaskPool
|
||||||
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
|
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK
|
||||||
from ..internals.helpers import return_or_exception
|
from ..internals.helpers import return_or_exception
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -72,6 +75,7 @@ class ControlSession:
|
|||||||
self._reader: StreamReader = reader
|
self._reader: StreamReader = reader
|
||||||
self._writer: StreamWriter = writer
|
self._writer: StreamWriter = writer
|
||||||
self._parser: Optional[ControlParser] = None
|
self._parser: Optional[ControlParser] = None
|
||||||
|
self._response_buffer: StringIO = StringIO()
|
||||||
|
|
||||||
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
|
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
@ -130,10 +134,11 @@ class ControlSession:
|
|||||||
Client info is retrieved, server info is sent back, and the
|
Client info is retrieved, server info is sent back, and the
|
||||||
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
|
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
|
||||||
"""
|
"""
|
||||||
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
|
msg = (await self._reader.readline()).decode().strip()
|
||||||
|
client_info = json.loads(msg)
|
||||||
log.debug("%s connected", self._client_class_name)
|
log.debug("%s connected", self._client_class_name)
|
||||||
parser_kwargs = {
|
parser_kwargs = {
|
||||||
STREAM_WRITER: self._writer,
|
'stream': self._response_buffer,
|
||||||
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
|
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
|
||||||
'prog': '',
|
'prog': '',
|
||||||
'usage': f'[-h] [{CMD}] ...'
|
'usage': f'[-h] [{CMD}] ...'
|
||||||
@ -143,6 +148,7 @@ class ControlSession:
|
|||||||
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
|
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
|
||||||
self._parser.add_class_commands(self._pool.__class__)
|
self._parser.add_class_commands(self._pool.__class__)
|
||||||
self._writer.write(str(self._pool).encode())
|
self._writer.write(str(self._pool).encode())
|
||||||
|
self._writer.write(b'\n')
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
|
||||||
async def _parse_command(self, msg: str) -> None:
|
async def _parse_command(self, msg: str) -> None:
|
||||||
@ -160,7 +166,7 @@ class ControlSession:
|
|||||||
kwargs = vars(self._parser.parse_args(msg.split(' ')))
|
kwargs = vars(self._parser.parse_args(msg.split(' ')))
|
||||||
except ArgumentError as e:
|
except ArgumentError as e:
|
||||||
log.debug("%s got an ArgumentError", self._client_class_name)
|
log.debug("%s got an ArgumentError", self._client_class_name)
|
||||||
self._writer.write(str(e).encode())
|
self._response_buffer.write(str(e))
|
||||||
return
|
return
|
||||||
except (HelpRequested, ParserError):
|
except (HelpRequested, ParserError):
|
||||||
log.debug("%s received usage help", self._client_class_name)
|
log.debug("%s received usage help", self._client_class_name)
|
||||||
@ -171,7 +177,7 @@ class ControlSession:
|
|||||||
elif isinstance(command, property):
|
elif isinstance(command, property):
|
||||||
await self._exec_property_and_respond(command, **kwargs)
|
await self._exec_property_and_respond(command, **kwargs)
|
||||||
else:
|
else:
|
||||||
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode())
|
self._response_buffer.write(str(CommandError(f"Unknown command object: {command}")))
|
||||||
|
|
||||||
async def listen(self) -> None:
|
async def listen(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -183,9 +189,14 @@ class ControlSession:
|
|||||||
It will obviously block indefinitely.
|
It will obviously block indefinitely.
|
||||||
"""
|
"""
|
||||||
while self._control_server.is_serving():
|
while self._control_server.is_serving():
|
||||||
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
|
msg = (await self._reader.readline()).decode().strip()
|
||||||
if not msg:
|
if not msg:
|
||||||
log.debug("%s disconnected", self._client_class_name)
|
log.debug("%s disconnected", self._client_class_name)
|
||||||
break
|
break
|
||||||
await self._parse_command(msg)
|
await self._parse_command(msg)
|
||||||
|
response = self._response_buffer.getvalue()
|
||||||
|
self._response_buffer.seek(0)
|
||||||
|
self._response_buffer.truncate()
|
||||||
|
self._writer.write(response.encode())
|
||||||
|
self._writer.write(b'\n')
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
@ -21,13 +21,17 @@ This module should **not** be considered part of the public API.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
|
|
||||||
|
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
|
||||||
|
|
||||||
DEFAULT_TASK_GROUP = 'default'
|
DEFAULT_TASK_GROUP = 'default'
|
||||||
|
|
||||||
SESSION_MSG_BYTES = 1024 * 100
|
SESSION_MSG_BYTES = 1024 * 100
|
||||||
|
|
||||||
STREAM_WRITER = 'stream_writer'
|
|
||||||
CMD = 'command'
|
CMD = 'command'
|
||||||
CMD_OK = b"ok"
|
CMD_OK = b"ok"
|
||||||
|
|
||||||
|
@ -19,11 +19,13 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import builtins
|
||||||
from asyncio.coroutines import iscoroutinefunction
|
from asyncio.coroutines import iscoroutinefunction
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from inspect import getdoc
|
from inspect import getdoc
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Callable, Optional, Type, Union
|
||||||
|
|
||||||
|
from .constants import PYTHON_BEFORE_39
|
||||||
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
||||||
|
|
||||||
|
|
||||||
@ -131,3 +133,25 @@ def resolve_dotted_path(dotted_path: str) -> object:
|
|||||||
import_module(module_name)
|
import_module(module_name)
|
||||||
found = getattr(found, name)
|
found = getattr(found, name)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
class ClassMethodWorkaround:
|
||||||
|
"""Dirty workaround to make the `@classmethod` decorator work with properties."""
|
||||||
|
|
||||||
|
def __init__(self, method_or_property: Union[Callable, property]) -> None:
|
||||||
|
if isinstance(method_or_property, property):
|
||||||
|
self._getter = method_or_property.fget
|
||||||
|
else:
|
||||||
|
self._getter = method_or_property
|
||||||
|
|
||||||
|
def __get__(self, obj: Union[T, None], cls: Union[Type[T], None]) -> Any:
|
||||||
|
if obj is None:
|
||||||
|
return self._getter(cls)
|
||||||
|
return self._getter(obj)
|
||||||
|
|
||||||
|
|
||||||
|
# Starting with Python 3.9, this is thankfully no longer necessary.
|
||||||
|
if PYTHON_BEFORE_39:
|
||||||
|
classmethod = ClassMethodWorkaround
|
||||||
|
else:
|
||||||
|
classmethod = builtins.classmethod
|
||||||
|
@ -31,8 +31,8 @@ T = TypeVar('T')
|
|||||||
ArgsT = Iterable[Any]
|
ArgsT = Iterable[Any]
|
||||||
KwArgsT = Mapping[str, Any]
|
KwArgsT = Mapping[str, Any]
|
||||||
|
|
||||||
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
|
AnyCallableT = Callable[..., Union[T, Awaitable[T]]]
|
||||||
CoroutineFunc = Callable[[...], Coroutine]
|
CoroutineFunc = Callable[..., Coroutine]
|
||||||
|
|
||||||
EndCB = Callable
|
EndCB = Callable
|
||||||
CancelCB = Callable
|
CancelCB = Callable
|
||||||
|
@ -28,16 +28,17 @@ For further details about the classes check their respective documentation.
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Event, Semaphore
|
||||||
from asyncio.tasks import Task, create_task, gather
|
from asyncio.tasks import Task, create_task, gather
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from math import inf
|
from math import inf
|
||||||
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
from .internals.constants import DEFAULT_TASK_GROUP
|
from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
|
||||||
from .internals.group_register import TaskGroupRegister
|
from .internals.group_register import TaskGroupRegister
|
||||||
from .internals.helpers import execute_optional, star_function
|
from .internals.helpers import execute_optional, star_function
|
||||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||||
@ -71,7 +72,7 @@ class BaseTaskPool:
|
|||||||
|
|
||||||
# Initialize flags; immutably set the name.
|
# Initialize flags; immutably set the name.
|
||||||
self._locked: bool = False
|
self._locked: bool = False
|
||||||
self._closed: bool = False
|
self._closed: Event = Event()
|
||||||
self._name: str = name
|
self._name: str = name
|
||||||
|
|
||||||
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
||||||
@ -220,7 +221,7 @@ class BaseTaskPool:
|
|||||||
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
||||||
if function and not iscoroutinefunction(function):
|
if function and not iscoroutinefunction(function):
|
||||||
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
||||||
if self._closed:
|
if self._closed.is_set():
|
||||||
raise exceptions.PoolIsClosed("You must use another pool")
|
raise exceptions.PoolIsClosed("You must use another pool")
|
||||||
if self._locked and not ignore_lock:
|
if self._locked and not ignore_lock:
|
||||||
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
||||||
@ -360,6 +361,23 @@ class BaseTaskPool:
|
|||||||
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
||||||
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Returns a dictionary to unpack in a `Task.cancel()` method.
|
||||||
|
|
||||||
|
This method exists to ensure proper compatibility with older Python versions.
|
||||||
|
If `msg` is `None`, an empty dictionary is returned.
|
||||||
|
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
|
||||||
|
Otherwise the keyword dictionary contains the `msg` parameter.
|
||||||
|
"""
|
||||||
|
if msg is None:
|
||||||
|
return {}
|
||||||
|
if PYTHON_BEFORE_39:
|
||||||
|
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
|
||||||
|
return {}
|
||||||
|
return {'msg': msg}
|
||||||
|
|
||||||
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels the tasks with the specified IDs.
|
Cancels the tasks with the specified IDs.
|
||||||
@ -378,8 +396,9 @@ class BaseTaskPool:
|
|||||||
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
||||||
"""
|
"""
|
||||||
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel(msg=msg)
|
task.cancel(**kw)
|
||||||
|
|
||||||
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
||||||
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
||||||
@ -392,7 +411,7 @@ class BaseTaskPool:
|
|||||||
self._meta_tasks_cancelled.update(meta_tasks)
|
self._meta_tasks_cancelled.update(meta_tasks)
|
||||||
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
||||||
|
|
||||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
|
||||||
"""
|
"""
|
||||||
Removes all tasks from the specified group and cancels them.
|
Removes all tasks from the specified group and cancels them.
|
||||||
|
|
||||||
@ -406,7 +425,7 @@ class BaseTaskPool:
|
|||||||
self._cancel_group_meta_tasks(group_name)
|
self._cancel_group_meta_tasks(group_name)
|
||||||
while group_reg:
|
while group_reg:
|
||||||
try:
|
try:
|
||||||
self._tasks_running[group_reg.pop()].cancel(msg=msg)
|
self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
continue
|
continue
|
||||||
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
||||||
@ -433,7 +452,8 @@ class BaseTaskPool:
|
|||||||
group_reg = self._task_groups.pop(group_name)
|
group_reg = self._task_groups.pop(group_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
kw = self._get_cancel_kw(msg)
|
||||||
|
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
||||||
log.debug("%s forgot task group %s", str(self), group_name)
|
log.debug("%s forgot task group %s", str(self), group_name)
|
||||||
|
|
||||||
def cancel_all(self, msg: str = None) -> None:
|
def cancel_all(self, msg: str = None) -> None:
|
||||||
@ -448,9 +468,10 @@ class BaseTaskPool:
|
|||||||
msg (optional): Passed to the `Task.cancel()` method of every task.
|
msg (optional): Passed to the `Task.cancel()` method of every task.
|
||||||
"""
|
"""
|
||||||
log.warning("%s cancelling all tasks!", str(self))
|
log.warning("%s cancelling all tasks!", str(self))
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
while self._task_groups:
|
while self._task_groups:
|
||||||
group_name, group_reg = self._task_groups.popitem()
|
group_name, group_reg = self._task_groups.popitem()
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
||||||
|
|
||||||
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
||||||
"""
|
"""
|
||||||
@ -529,9 +550,16 @@ class BaseTaskPool:
|
|||||||
self._tasks_ended.clear()
|
self._tasks_ended.clear()
|
||||||
self._tasks_cancelled.clear()
|
self._tasks_cancelled.clear()
|
||||||
self._tasks_running.clear()
|
self._tasks_running.clear()
|
||||||
self._closed = True
|
self._closed.set()
|
||||||
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
|
|
||||||
# await it to allow blocking until a closing command comes from a server.
|
async def until_closed(self) -> bool:
|
||||||
|
"""
|
||||||
|
Waits until the pool has been closed. (This method itself does **not** close the pool, but blocks until then.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`True` once the pool is closed.
|
||||||
|
"""
|
||||||
|
return await self._closed.wait()
|
||||||
|
|
||||||
|
|
||||||
class TaskPool(BaseTaskPool):
|
class TaskPool(BaseTaskPool):
|
||||||
@ -604,7 +632,7 @@ class TaskPool(BaseTaskPool):
|
|||||||
# This means there was probably something wrong with the function arguments.
|
# This means there was probably something wrong with the function arguments.
|
||||||
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
||||||
continue
|
continue # TODO: Consider returning instead of continuing
|
||||||
try:
|
try:
|
||||||
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
||||||
cancel_callback=cancel_callback)
|
cancel_callback=cancel_callback)
|
||||||
@ -640,7 +668,7 @@ class TaskPool(BaseTaskPool):
|
|||||||
kwargs (optional):
|
kwargs (optional):
|
||||||
The keyword-arguments to pass into each function call.
|
The keyword-arguments to pass into each function call.
|
||||||
num (optional):
|
num (optional):
|
||||||
The number of tasks to spawn with the specified parameters.
|
The number of tasks to spawn with the specified parameters. Defaults to 1.
|
||||||
group_name (optional):
|
group_name (optional):
|
||||||
Name of the task group to add the new tasks to. By default, a unique name is constructed in the form
|
Name of the task group to add the new tasks to. By default, a unique name is constructed in the form
|
||||||
:code:`'apply-{name}-group-{idx}'` (with `name` being the name of the `func` and `idx` being an
|
:code:`'apply-{name}-group-{idx}'` (with `name` being the name of the `func` and `idx` being an
|
||||||
@ -737,9 +765,10 @@ class TaskPool(BaseTaskPool):
|
|||||||
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||||
"""
|
"""
|
||||||
Creates tasks in the pool with arguments from the supplied iterable.
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
|
|
||||||
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
||||||
|
The method is a task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
@ -791,10 +820,10 @@ class TaskPool(BaseTaskPool):
|
|||||||
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
|
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
|
||||||
|
|
||||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
|
|
||||||
|
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`. The method is a task-based
|
||||||
|
equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
@ -848,6 +877,8 @@ class TaskPool(BaseTaskPool):
|
|||||||
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
|
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
|
|
||||||
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
||||||
as positional arguments to the function.
|
as positional arguments to the function.
|
||||||
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
||||||
@ -865,6 +896,8 @@ class TaskPool(BaseTaskPool):
|
|||||||
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
|
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
|
||||||
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
|
|
||||||
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
|
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
|
||||||
unpacked as keyword-arguments to the function.
|
unpacked as keyword-arguments to the function.
|
||||||
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
||||||
@ -941,13 +974,24 @@ class SimpleTaskPool(BaseTaskPool):
|
|||||||
|
|
||||||
async def _start_num(self, num: int, group_name: str) -> None:
|
async def _start_num(self, num: int, group_name: str) -> None:
|
||||||
"""Starts `num` new tasks in group `group_name`."""
|
"""Starts `num` new tasks in group `group_name`."""
|
||||||
start_coroutines = (
|
for i in range(num):
|
||||||
self._start_task(self._func(*self._args, **self._kwargs), group_name=group_name,
|
try:
|
||||||
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
|
coroutine = self._func(*self._args, **self._kwargs)
|
||||||
for _ in range(num)
|
except Exception as e:
|
||||||
)
|
# This means there was probably something wrong with the function arguments.
|
||||||
# TODO: Same deal as with the other meta tasks, provide proper cancellation handling!
|
log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
await gather(*start_coroutines)
|
str(e.__class__.__name__), str(self), self._func.__name__,
|
||||||
|
repr(self._args), repr(self._kwargs))
|
||||||
|
continue # TODO: Consider returning instead of continuing
|
||||||
|
try:
|
||||||
|
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
|
||||||
|
cancel_callback=self._cancel_callback)
|
||||||
|
except CancelledError:
|
||||||
|
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
|
||||||
|
# more tasks and can return immediately.
|
||||||
|
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
|
||||||
|
coroutine.close()
|
||||||
|
return
|
||||||
|
|
||||||
def start(self, num: int) -> str:
|
def start(self, num: int) -> str:
|
||||||
"""
|
"""
|
||||||
|
30
tests/__main__.py
Normal file
30
tests/__main__.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Main entry point for all unit tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_suite = unittest.defaultTestLoader.discover('.')
|
||||||
|
test_runner = unittest.TextTestRunner(resultclass=unittest.TextTestResult)
|
||||||
|
result = test_runner.run(test_suite)
|
||||||
|
sys.exit(not result.wasSuccessful())
|
@ -38,7 +38,7 @@ class CLITestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_client = MagicMock(start=mock_client_start)
|
mock_client = MagicMock(start=mock_client_start)
|
||||||
mock_client_cls = MagicMock(return_value=mock_client)
|
mock_client_cls = MagicMock(return_value=mock_client)
|
||||||
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
|
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
|
||||||
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
|
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls, **mock_client_kwargs}
|
||||||
self.assertIsNone(await module.main())
|
self.assertIsNone(await module.main())
|
||||||
mock_parse_cli.assert_called_once_with()
|
mock_parse_cli.assert_called_once_with()
|
||||||
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
|
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
|
||||||
|
@ -71,7 +71,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||||
self.assertTrue(self.client._connected)
|
self.assertTrue(self.client._connected)
|
||||||
mock__client_info.assert_called_once_with()
|
mock__client_info.assert_called_once_with()
|
||||||
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
self.mock_write.assert_has_calls([call(json.dumps(mock_info).encode()), call(b'\n')])
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
self.mock_print.assert_has_calls([
|
self.mock_print.assert_has_calls([
|
||||||
@ -121,7 +121,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
||||||
self.mock_drain.side_effect = err = ConnectionError()
|
self.mock_drain.side_effect = err = ConnectionError()
|
||||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
self.mock_write.assert_called_once_with(cmd.encode())
|
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_not_awaited()
|
self.mock_read.assert_not_awaited()
|
||||||
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||||
@ -133,7 +133,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.mock_print.reset_mock()
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
self.mock_write.assert_called_once_with(cmd.encode())
|
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
self.mock_print.assert_called_once_with(FOO)
|
self.mock_print.assert_called_once_with(FOO)
|
||||||
|
@ -41,9 +41,9 @@ class ControlParserTestCase(TestCase):
|
|||||||
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
|
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
|
||||||
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
||||||
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
||||||
self.stream_writer, self.terminal_width = MagicMock(), 420
|
self.stream, self.terminal_width = MagicMock(), 420
|
||||||
self.kwargs = {
|
self.kwargs = {
|
||||||
'stream_writer': self.stream_writer,
|
'stream': self.stream,
|
||||||
'terminal_width': self.terminal_width,
|
'terminal_width': self.terminal_width,
|
||||||
'formatter_class': FOO
|
'formatter_class': FOO
|
||||||
}
|
}
|
||||||
@ -72,10 +72,9 @@ class ControlParserTestCase(TestCase):
|
|||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assertIsInstance(self.parser, ArgumentParser)
|
self.assertIsInstance(self.parser, ArgumentParser)
|
||||||
self.assertEqual(self.stream_writer, self.parser._stream_writer)
|
self.assertEqual(self.stream, self.parser._stream)
|
||||||
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
||||||
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
||||||
self.assertFalse(getattr(self.parser, 'exit_on_error'))
|
|
||||||
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
||||||
self.assertSetEqual(set(), self.parser._flags)
|
self.assertSetEqual(set(), self.parser._flags)
|
||||||
self.assertIsNone(self.parser._commands)
|
self.assertIsNone(self.parser._commands)
|
||||||
@ -89,7 +88,7 @@ class ControlParserTestCase(TestCase):
|
|||||||
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
||||||
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
||||||
expected_name = 'foo-bar'
|
expected_name = 'foo-bar'
|
||||||
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
|
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
|
||||||
to_omit = ['abc', 'xyz']
|
to_omit = ['abc', 'xyz']
|
||||||
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
|
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
|
||||||
self.assertEqual(mock_subparser, output)
|
self.assertEqual(mock_subparser, output)
|
||||||
@ -107,7 +106,7 @@ class ControlParserTestCase(TestCase):
|
|||||||
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
||||||
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
||||||
expected_name = 'get-prop'
|
expected_name = 'get-prop'
|
||||||
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
|
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
|
||||||
output = self.parser.add_property_command(prop, **kwargs)
|
output = self.parser.add_property_command(prop, **kwargs)
|
||||||
self.assertEqual(mock_subparser, output)
|
self.assertEqual(mock_subparser, output)
|
||||||
mock_get_first_doc_line.assert_called_once_with(get_prop)
|
mock_get_first_doc_line.assert_called_once_with(get_prop)
|
||||||
@ -119,7 +118,7 @@ class ControlParserTestCase(TestCase):
|
|||||||
|
|
||||||
prop = property(get_prop, set_prop)
|
prop = property(get_prop, set_prop)
|
||||||
expected_help = f"Get/set the `.{expected_name}` property"
|
expected_help = f"Get/set the `.{expected_name}` property"
|
||||||
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs
|
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help, **kwargs}
|
||||||
output = self.parser.add_property_command(prop, **kwargs)
|
output = self.parser.add_property_command(prop, **kwargs)
|
||||||
self.assertEqual(mock_subparser, output)
|
self.assertEqual(mock_subparser, output)
|
||||||
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
|
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
|
||||||
@ -152,8 +151,7 @@ class ControlParserTestCase(TestCase):
|
|||||||
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
|
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
|
||||||
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
|
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
|
||||||
x = 'x'
|
x = 'x'
|
||||||
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer,
|
common_kwargs = {'stream': self.parser._stream, parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
|
||||||
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
|
|
||||||
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
|
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
|
||||||
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
|
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
|
||||||
self.assertDictEqual(expected_output, output)
|
self.assertDictEqual(expected_output, output)
|
||||||
@ -170,12 +168,12 @@ class ControlParserTestCase(TestCase):
|
|||||||
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
|
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
|
||||||
|
|
||||||
def test__print_message(self):
|
def test__print_message(self):
|
||||||
self.stream_writer.write = MagicMock()
|
self.stream.write = MagicMock()
|
||||||
self.assertIsNone(self.parser._print_message(''))
|
self.assertIsNone(self.parser._print_message(''))
|
||||||
self.stream_writer.write.assert_not_called()
|
self.stream.write.assert_not_called()
|
||||||
msg = 'foo bar baz'
|
msg = 'foo bar baz'
|
||||||
self.assertIsNone(self.parser._print_message(msg))
|
self.assertIsNone(self.parser._print_message(msg))
|
||||||
self.stream_writer.write.assert_called_once_with(msg.encode())
|
self.stream.write.assert_called_once_with(msg)
|
||||||
|
|
||||||
@patch.object(parser.ControlParser, '_print_message')
|
@patch.object(parser.ControlParser, '_print_message')
|
||||||
def test_exit(self, mock__print_message: MagicMock):
|
def test_exit(self, mock__print_message: MagicMock):
|
||||||
|
@ -21,11 +21,12 @@ Unittests for the `asyncio_taskpool.session` module.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from argparse import ArgumentError, Namespace
|
from argparse import ArgumentError, Namespace
|
||||||
|
from io import StringIO
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||||
|
|
||||||
from asyncio_taskpool.control import session
|
from asyncio_taskpool.control import session
|
||||||
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER
|
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD
|
||||||
from asyncio_taskpool.exceptions import HelpRequested
|
from asyncio_taskpool.exceptions import HelpRequested
|
||||||
from asyncio_taskpool.pool import SimpleTaskPool
|
from asyncio_taskpool.pool import SimpleTaskPool
|
||||||
|
|
||||||
@ -61,14 +62,15 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(self.mock_reader, self.session._reader)
|
self.assertEqual(self.mock_reader, self.session._reader)
|
||||||
self.assertEqual(self.mock_writer, self.session._writer)
|
self.assertEqual(self.mock_writer, self.session._writer)
|
||||||
self.assertIsNone(self.session._parser)
|
self.assertIsNone(self.session._parser)
|
||||||
|
self.assertIsInstance(self.session._response_buffer, StringIO)
|
||||||
|
|
||||||
@patch.object(session, 'return_or_exception')
|
@patch.object(session, 'return_or_exception')
|
||||||
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
|
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
|
||||||
def method(self, arg1, arg2, *var_args, **rest): pass
|
def method(self, arg1, arg2, *var_args, **rest): pass
|
||||||
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
|
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
|
||||||
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest
|
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args}
|
||||||
mock_return_or_exception.return_value = None
|
mock_return_or_exception.return_value = None
|
||||||
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs))
|
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs, **test_rest))
|
||||||
mock_return_or_exception.assert_awaited_once_with(
|
mock_return_or_exception.assert_awaited_once_with(
|
||||||
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
|
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
|
||||||
)
|
)
|
||||||
@ -100,11 +102,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parser_cls.return_value = mock_parser
|
mock_parser_cls.return_value = mock_parser
|
||||||
width = 5678
|
width = 5678
|
||||||
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
||||||
mock_read = AsyncMock(return_value=msg.encode())
|
mock_readline = AsyncMock(return_value=msg.encode())
|
||||||
self.mock_reader.read = mock_read
|
self.mock_reader.readline = mock_readline
|
||||||
self.mock_writer.drain = AsyncMock()
|
self.mock_writer.drain = AsyncMock()
|
||||||
expected_parser_kwargs = {
|
expected_parser_kwargs = {
|
||||||
STREAM_WRITER: self.mock_writer,
|
'stream': self.session._response_buffer,
|
||||||
CLIENT_INFO.TERMINAL_WIDTH: width,
|
CLIENT_INFO.TERMINAL_WIDTH: width,
|
||||||
'prog': '',
|
'prog': '',
|
||||||
'usage': f'[-h] [{CMD}] ...'
|
'usage': f'[-h] [{CMD}] ...'
|
||||||
@ -115,11 +117,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
}
|
}
|
||||||
self.assertIsNone(await self.session.client_handshake())
|
self.assertIsNone(await self.session.client_handshake())
|
||||||
self.assertEqual(mock_parser, self.session._parser)
|
self.assertEqual(mock_parser, self.session._parser)
|
||||||
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
mock_readline.assert_awaited_once_with()
|
||||||
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
|
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
|
||||||
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
|
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
|
||||||
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
|
self.mock_writer.write.assert_has_calls([call(str(self.mock_pool).encode()), call(b'\n')])
|
||||||
self.mock_writer.drain.assert_awaited_once_with()
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
@patch.object(session.ControlSession, '_exec_property_and_respond')
|
@patch.object(session.ControlSession, '_exec_property_and_respond')
|
||||||
@ -132,10 +134,9 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
kwargs = {FOO: BAR, 'hello': 'python'}
|
kwargs = {FOO: BAR, 'hello': 'python'}
|
||||||
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
|
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
|
||||||
self.session._parser = MagicMock(parse_args=mock_parse_args)
|
self.session._parser = MagicMock(parse_args=mock_parse_args)
|
||||||
self.mock_writer.write = MagicMock()
|
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_not_called()
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
|
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
|
||||||
mock__exec_property_and_respond.assert_not_called()
|
mock__exec_property_and_respond.assert_not_called()
|
||||||
|
|
||||||
@ -145,7 +146,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
|
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_not_called()
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_not_called()
|
mock__exec_method_and_respond.assert_not_called()
|
||||||
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
|
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
|
||||||
|
|
||||||
@ -161,47 +162,55 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
mock__exec_method_and_respond.assert_not_called()
|
mock__exec_method_and_respond.assert_not_called()
|
||||||
mock__exec_property_and_respond.assert_not_called()
|
mock__exec_property_and_respond.assert_not_called()
|
||||||
self.mock_writer.write.assert_called_once_with(str(exc).encode())
|
self.assertEqual(str(exc), self.session._response_buffer.getvalue())
|
||||||
|
|
||||||
mock__exec_property_and_respond.reset_mock()
|
mock__exec_property_and_respond.reset_mock()
|
||||||
mock_parse_args.reset_mock()
|
mock_parse_args.reset_mock()
|
||||||
self.mock_writer.write.reset_mock()
|
self.session._response_buffer.seek(0)
|
||||||
|
self.session._response_buffer.truncate()
|
||||||
|
|
||||||
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
|
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_called_once_with(str(exc).encode())
|
self.assertEqual(str(exc), self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_not_awaited()
|
mock__exec_method_and_respond.assert_not_awaited()
|
||||||
mock__exec_property_and_respond.assert_not_awaited()
|
mock__exec_property_and_respond.assert_not_awaited()
|
||||||
|
|
||||||
self.mock_writer.write.reset_mock()
|
|
||||||
mock_parse_args.reset_mock()
|
mock_parse_args.reset_mock()
|
||||||
|
self.session._response_buffer.seek(0)
|
||||||
|
self.session._response_buffer.truncate()
|
||||||
|
|
||||||
mock_parse_args.side_effect = HelpRequested()
|
mock_parse_args.side_effect = HelpRequested()
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_not_called()
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_not_awaited()
|
mock__exec_method_and_respond.assert_not_awaited()
|
||||||
mock__exec_property_and_respond.assert_not_awaited()
|
mock__exec_property_and_respond.assert_not_awaited()
|
||||||
|
|
||||||
@patch.object(session.ControlSession, '_parse_command')
|
@patch.object(session.ControlSession, '_parse_command')
|
||||||
async def test_listen(self, mock__parse_command: AsyncMock):
|
async def test_listen(self, mock__parse_command: AsyncMock):
|
||||||
def make_reader_return_empty():
|
def make_reader_return_empty():
|
||||||
self.mock_reader.read.return_value = b''
|
self.mock_reader.readline.return_value = b''
|
||||||
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
||||||
msg = "fascinating"
|
msg = "fascinating"
|
||||||
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
|
self.mock_reader.readline = AsyncMock(return_value=f' {msg} '.encode())
|
||||||
|
response = FOO + BAR + FOO
|
||||||
|
self.session._response_buffer.write(response)
|
||||||
self.assertIsNone(await self.session.listen())
|
self.assertIsNone(await self.session.listen())
|
||||||
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
|
self.mock_reader.readline.assert_has_awaits([call(), call()])
|
||||||
mock__parse_command.assert_awaited_once_with(msg)
|
mock__parse_command.assert_awaited_once_with(msg)
|
||||||
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
|
self.mock_writer.write.assert_has_calls([call(response.encode()), call(b'\n')])
|
||||||
self.mock_writer.drain.assert_awaited_once_with()
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
self.mock_reader.read.reset_mock()
|
self.mock_reader.readline.reset_mock()
|
||||||
mock__parse_command.reset_mock()
|
mock__parse_command.reset_mock()
|
||||||
|
self.mock_writer.write.reset_mock()
|
||||||
self.mock_writer.drain.reset_mock()
|
self.mock_writer.drain.reset_mock()
|
||||||
|
|
||||||
self.mock_server.is_serving = MagicMock(return_value=False)
|
self.mock_server.is_serving = MagicMock(return_value=False)
|
||||||
self.assertIsNone(await self.session.listen())
|
self.assertIsNone(await self.session.listen())
|
||||||
self.mock_reader.read.assert_not_awaited()
|
self.mock_reader.readline.assert_not_awaited()
|
||||||
mock__parse_command.assert_not_awaited()
|
mock__parse_command.assert_not_awaited()
|
||||||
|
self.mock_writer.write.assert_not_called()
|
||||||
self.mock_writer.drain.assert_not_awaited()
|
self.mock_writer.drain.assert_not_awaited()
|
||||||
|
@ -18,10 +18,11 @@ __doc__ = """
|
|||||||
Unittests for the `asyncio_taskpool.helpers` module.
|
Unittests for the `asyncio_taskpool.helpers` module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase, TestCase
|
||||||
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool.internals import constants
|
||||||
from asyncio_taskpool.internals import helpers
|
from asyncio_taskpool.internals import helpers
|
||||||
|
|
||||||
|
|
||||||
@ -122,3 +123,45 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
|||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
helpers.resolve_dotted_path('foo.bar.baz')
|
helpers.resolve_dotted_path('foo.bar.baz')
|
||||||
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
|
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassMethodWorkaroundTestCase(TestCase):
|
||||||
|
def test_init(self):
|
||||||
|
def func(): return 'foo'
|
||||||
|
def getter(): return 'bar'
|
||||||
|
prop = property(getter)
|
||||||
|
instance = helpers.ClassMethodWorkaround(func)
|
||||||
|
self.assertIs(func, instance._getter)
|
||||||
|
instance = helpers.ClassMethodWorkaround(prop)
|
||||||
|
self.assertIs(getter, instance._getter)
|
||||||
|
|
||||||
|
@patch.object(helpers.ClassMethodWorkaround, '__init__', return_value=None)
|
||||||
|
def test_get(self, _mock_init: MagicMock):
|
||||||
|
def func(x: MagicMock): return x.__name__
|
||||||
|
instance = helpers.ClassMethodWorkaround(MagicMock())
|
||||||
|
instance._getter = func
|
||||||
|
obj, cls = None, MagicMock
|
||||||
|
expected_output = 'MagicMock'
|
||||||
|
output = instance.__get__(obj, cls)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
obj = MagicMock(__name__='bar')
|
||||||
|
expected_output = 'bar'
|
||||||
|
output = instance.__get__(obj, cls)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
cls = None
|
||||||
|
output = instance.__get__(obj, cls)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
def test_correct_class(self):
|
||||||
|
is_older_python = constants.PYTHON_BEFORE_39
|
||||||
|
try:
|
||||||
|
constants.PYTHON_BEFORE_39 = True
|
||||||
|
importlib.reload(helpers)
|
||||||
|
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
|
||||||
|
constants.PYTHON_BEFORE_39 = False
|
||||||
|
importlib.reload(helpers)
|
||||||
|
self.assertIs(classmethod, helpers.classmethod)
|
||||||
|
finally:
|
||||||
|
constants.PYTHON_BEFORE_39 = is_older_python
|
||||||
|
@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Event, Semaphore
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||||
from typing import Type
|
from typing import Type
|
||||||
@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertEqual(0, self.task_pool._num_started)
|
self.assertEqual(0, self.task_pool._num_started)
|
||||||
|
|
||||||
self.assertFalse(self.task_pool._locked)
|
self.assertFalse(self.task_pool._locked)
|
||||||
self.assertFalse(self.task_pool._closed)
|
self.assertIsInstance(self.task_pool._closed, Event)
|
||||||
|
self.assertFalse(self.task_pool._closed.is_set())
|
||||||
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
||||||
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool.get_group_ids(group_name, 'something else')
|
self.task_pool.get_group_ids(group_name, 'something else')
|
||||||
|
|
||||||
async def test__check_start(self):
|
async def test__check_start(self):
|
||||||
self.task_pool._closed = True
|
self.task_pool._closed.set()
|
||||||
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
||||||
try:
|
try:
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
||||||
with self.assertRaises(exceptions.PoolIsClosed):
|
with self.assertRaises(exceptions.PoolIsClosed):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
||||||
self.task_pool._closed = False
|
self.task_pool._closed.clear()
|
||||||
self.task_pool._locked = True
|
self.task_pool._locked = True
|
||||||
with self.assertRaises(exceptions.PoolIsLocked):
|
with self.assertRaises(exceptions.PoolIsLocked):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
||||||
@ -311,13 +312,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_not_called()
|
mock__task_name.assert_not_called()
|
||||||
|
|
||||||
|
@patch('warnings.warn')
|
||||||
|
def test__get_cancel_kw(self, mock_warn: MagicMock):
|
||||||
|
msg = None
|
||||||
|
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
msg = 'something'
|
||||||
|
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
|
||||||
|
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_called_once()
|
||||||
|
mock_warn.reset_mock()
|
||||||
|
|
||||||
|
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
|
||||||
|
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
||||||
def test_cancel(self, mock__get_running_task: MagicMock):
|
def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
task_id1, task_id2, task_id3 = 1, 4, 9
|
task_id1, task_id2, task_id3 = 1, 4, 9
|
||||||
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
||||||
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
||||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||||
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
mock__get_cancel_kw.assert_called_once_with(FOO)
|
||||||
|
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
|
||||||
|
|
||||||
def test__cancel_group_meta_tasks(self):
|
def test__cancel_group_meta_tasks(self):
|
||||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||||
@ -336,6 +356,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
||||||
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
||||||
|
kw = {BAR: 10, BAZ: 20}
|
||||||
task_id = 555
|
task_id = 555
|
||||||
mock_cancel = MagicMock()
|
mock_cancel = MagicMock()
|
||||||
|
|
||||||
@ -347,27 +368,33 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
class MockRegister(set, MagicMock):
|
class MockRegister(set, MagicMock):
|
||||||
pass
|
pass
|
||||||
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
|
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
|
||||||
mock_cancel.assert_called_once_with(msg=FOO)
|
mock_cancel.assert_called_once_with(**kw)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
||||||
with self.assertRaises(exceptions.InvalidGroupName):
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
self.task_pool.cancel_group(BAR)
|
self.task_pool.cancel_group(BAR)
|
||||||
mock__cancel_and_remove_all_from_group.assert_not_called()
|
mock__cancel_and_remove_all_from_group.assert_not_called()
|
||||||
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||||
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
|
mock__get_cancel_kw.assert_called_once_with(BAR)
|
||||||
|
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
mock_group_reg = MagicMock()
|
mock_group_reg = MagicMock()
|
||||||
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
||||||
self.assertIsNone(self.task_pool.cancel_all('msg'))
|
self.assertIsNone(self.task_pool.cancel_all(BAZ))
|
||||||
|
mock__get_cancel_kw.assert_called_once_with(BAZ)
|
||||||
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
||||||
call(BAR, mock_group_reg, msg='msg'),
|
call(BAR, mock_group_reg, **fake_cancel_kw),
|
||||||
call(FOO, mock_group_reg, msg='msg')
|
call(FOO, mock_group_reg, **fake_cancel_kw)
|
||||||
])
|
])
|
||||||
|
|
||||||
def test__pop_ended_meta_tasks(self):
|
def test__pop_ended_meta_tasks(self):
|
||||||
@ -435,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
self.assertTrue(self.task_pool._closed)
|
self.assertTrue(self.task_pool._closed.is_set())
|
||||||
|
|
||||||
|
async def test_until_closed(self):
|
||||||
|
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
|
||||||
|
output = await self.task_pool.until_closed()
|
||||||
|
self.assertEqual(FOO, output)
|
||||||
|
self.task_pool._closed.wait.assert_awaited_once_with()
|
||||||
|
|
||||||
|
|
||||||
class TaskPoolTestCase(CommonTestCase):
|
class TaskPoolTestCase(CommonTestCase):
|
||||||
@ -729,13 +762,15 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
TEST_POOL_CANCEL_CB = MagicMock()
|
TEST_POOL_CANCEL_CB = MagicMock()
|
||||||
|
|
||||||
def get_task_pool_init_params(self) -> dict:
|
def get_task_pool_init_params(self) -> dict:
|
||||||
return super().get_task_pool_init_params() | {
|
params = super().get_task_pool_init_params()
|
||||||
|
params.update({
|
||||||
'func': self.TEST_POOL_FUNC,
|
'func': self.TEST_POOL_FUNC,
|
||||||
'args': self.TEST_POOL_ARGS,
|
'args': self.TEST_POOL_ARGS,
|
||||||
'kwargs': self.TEST_POOL_KWARGS,
|
'kwargs': self.TEST_POOL_KWARGS,
|
||||||
'end_callback': self.TEST_POOL_END_CB,
|
'end_callback': self.TEST_POOL_END_CB,
|
||||||
'cancel_callback': self.TEST_POOL_CANCEL_CB,
|
'cancel_callback': self.TEST_POOL_CANCEL_CB,
|
||||||
}
|
})
|
||||||
|
return params
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
|
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
|
||||||
@ -762,18 +797,30 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_task')
|
@patch.object(pool.SimpleTaskPool, '_start_task')
|
||||||
async def test__start_num(self, mock__start_task: AsyncMock):
|
async def test__start_num(self, mock__start_task: AsyncMock):
|
||||||
fake_coroutine = object()
|
|
||||||
self.task_pool._func = MagicMock(return_value=fake_coroutine)
|
|
||||||
num = 3
|
|
||||||
group_name = FOO + BAR + 'abc'
|
group_name = FOO + BAR + 'abc'
|
||||||
|
mock_awaitable1, mock_awaitable2 = object(), object()
|
||||||
|
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
|
||||||
|
num = 3
|
||||||
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
self.task_pool._func.assert_has_calls(num * [
|
self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
||||||
call(*self.task_pool._args, **self.task_pool._kwargs)
|
call_kw = {
|
||||||
])
|
'group_name': group_name,
|
||||||
mock__start_task.assert_has_awaits(num * [
|
'end_callback': self.task_pool._end_callback,
|
||||||
call(fake_coroutine, group_name=group_name, end_callback=self.task_pool._end_callback,
|
'cancel_callback': self.task_pool._cancel_callback
|
||||||
cancel_callback=self.task_pool._cancel_callback)
|
}
|
||||||
])
|
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
|
||||||
|
|
||||||
|
self.task_pool._func.reset_mock(side_effect=True)
|
||||||
|
mock__start_task.reset_mock()
|
||||||
|
|
||||||
|
# Simulate cancellation while the second task is being started.
|
||||||
|
mock__start_task.side_effect = [None, CancelledError, None]
|
||||||
|
mock_coroutine_to_close = MagicMock()
|
||||||
|
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
|
||||||
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
|
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
||||||
|
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
|
||||||
|
mock_coroutine_to_close.close.assert_called_once_with()
|
||||||
|
|
||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
||||||
|
Reference in New Issue
Block a user