generated from daniil-berg/boilerplate-py
Compare commits
32 Commits
v1.0.0-bet
...
master
Author | SHA1 | Date | |
---|---|---|---|
73aa93a9b7 | |||
051d0cb911 | |||
ee0b8c0002 | |||
28c997e0ee | |||
5a72a6d1d1 | |||
72e380cd77 | |||
85672bddeb | |||
dae883446a | |||
a4ecf39157 | |||
e3bbb05eac | |||
36527ccffc | |||
d047b99119 | |||
d7cd16c540 | |||
db306a1a1f | |||
3a8fcb2d5a | |||
cf02206588 | |||
0796038dcd | |||
f4e33baf82 | |||
9b838b6130 | |||
0daed04167 | |||
80fc91ec47 | |||
a72a7cc516 | |||
91d546ebc2 | |||
5b3ac52bf6 | |||
82e6ca7b1a | |||
153127e028 | |||
17539e9c27 | |||
1beb9fc9b0 | |||
23a4cb028a | |||
54e5bfa8a0 | |||
0e7e92a91b | |||
a9011076c4 |
@ -1,14 +1,12 @@
|
|||||||
[run]
|
[run]
|
||||||
source = src/
|
source = src/
|
||||||
branch = true
|
branch = true
|
||||||
omit =
|
command_line = -m unittest discover
|
||||||
.venv/*
|
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
|
fail_under = 100
|
||||||
show_missing = True
|
show_missing = True
|
||||||
skip_covered = False
|
skip_covered = False
|
||||||
exclude_lines =
|
exclude_lines =
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
if __name__ == ['"]__main__['"]:
|
if __name__ == ['"]__main__['"]:
|
||||||
omit =
|
|
||||||
tests/*
|
|
||||||
|
88
.github/workflows/main.yaml
vendored
Normal file
88
.github/workflows/main.yaml
vendored
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
name: CI
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [master]
|
||||||
|
jobs:
|
||||||
|
tests:
|
||||||
|
name: Python ${{ matrix.python-version }} Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version:
|
||||||
|
- '3.8'
|
||||||
|
- '3.9'
|
||||||
|
- '3.10'
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: 'requirements/dev.txt'
|
||||||
|
|
||||||
|
- name: Upgrade packaging tools
|
||||||
|
run: pip install -U pip
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -U -r requirements/dev.txt
|
||||||
|
|
||||||
|
- name: Install asyncio-taskpool
|
||||||
|
run: pip install -e .
|
||||||
|
|
||||||
|
- name: Run tests for Python ${{ matrix.python-version }}
|
||||||
|
if: ${{ matrix.python-version != '3.10' }}
|
||||||
|
run: python -m tests
|
||||||
|
|
||||||
|
- name: Run tests for Python 3.10 and save coverage
|
||||||
|
if: ${{ matrix.python-version == '3.10' }}
|
||||||
|
run: echo "coverage=$(./coverage.sh)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
coverage: ${{ env.coverage }}
|
||||||
|
|
||||||
|
update_badges:
|
||||||
|
needs: tests
|
||||||
|
name: Update Badges
|
||||||
|
env:
|
||||||
|
meta_gist_id: 3f8240a976e8781a765d9c74a583dcda
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Download `cloc`
|
||||||
|
run: sudo apt-get update -y && sudo apt-get install -y cloc
|
||||||
|
|
||||||
|
- name: Count lines of code/comments
|
||||||
|
run: |
|
||||||
|
echo "cloc_code=$(./cloc.sh -c src/)" >> $GITHUB_ENV
|
||||||
|
echo "cloc_comments=$(./cloc.sh -m src/)" >> $GITHUB_ENV
|
||||||
|
echo "cloc_commentpercent=$(./cloc.sh -p src/)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Create badge for lines of code
|
||||||
|
uses: Schneegans/dynamic-badges-action@v1.2.0
|
||||||
|
with:
|
||||||
|
auth: ${{ secrets.GIST_META_DATA }}
|
||||||
|
gistID: ${{ env.meta_gist_id }}
|
||||||
|
filename: cloc-code.json
|
||||||
|
label: Lines of Code
|
||||||
|
message: ${{ env.cloc_code }}
|
||||||
|
|
||||||
|
- name: Create badge for lines of comments
|
||||||
|
uses: Schneegans/dynamic-badges-action@v1.2.0
|
||||||
|
with:
|
||||||
|
auth: ${{ secrets.GIST_META_DATA }}
|
||||||
|
gistID: ${{ env.meta_gist_id }}
|
||||||
|
filename: cloc-comments.json
|
||||||
|
label: Comments
|
||||||
|
message: ${{ env.cloc_comments }} (${{ env.cloc_commentpercent }}%)
|
||||||
|
|
||||||
|
- name: Create badge for test coverage
|
||||||
|
uses: Schneegans/dynamic-badges-action@v1.2.0
|
||||||
|
with:
|
||||||
|
auth: ${{ secrets.GIST_META_DATA }}
|
||||||
|
gistID: ${{ env.meta_gist_id }}
|
||||||
|
filename: test-coverage.json
|
||||||
|
label: Coverage
|
||||||
|
message: ${{ needs.tests.outputs.coverage }}
|
11
.readthedocs.yaml
Normal file
11
.readthedocs.yaml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
version: 2
|
||||||
|
build:
|
||||||
|
os: 'ubuntu-20.04'
|
||||||
|
tools:
|
||||||
|
python: '3.8'
|
||||||
|
python:
|
||||||
|
install:
|
||||||
|
- method: pip
|
||||||
|
path: .
|
||||||
|
sphinx:
|
||||||
|
fail_on_warning: true
|
47
README.md
47
README.md
@ -1,7 +1,30 @@
|
|||||||
|
[//]: # (This file is part of asyncio-taskpool.)
|
||||||
|
|
||||||
|
[//]: # (asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of)
|
||||||
|
[//]: # (version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.)
|
||||||
|
|
||||||
|
[//]: # (asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;)
|
||||||
|
[//]: # (without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.)
|
||||||
|
[//]: # (See the GNU Lesser General Public License for more details.)
|
||||||
|
|
||||||
|
[//]: # (You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.)
|
||||||
|
[//]: # (If not, see <https://www.gnu.org/licenses/>.)
|
||||||
|
|
||||||
# asyncio-taskpool
|
# asyncio-taskpool
|
||||||
|
|
||||||
|
[![GitHub last commit][github-last-commit-img]][github-last-commit]
|
||||||
|
![Lines of code][gist-cloc-code-img]
|
||||||
|
![Lines of comments][gist-cloc-comments-img]
|
||||||
|
![Test coverage][gist-test-coverage-img]
|
||||||
|
[![License: LGPL v3.0][lgpl3-img]][lgpl3]
|
||||||
|
[![PyPI version][pypi-latest-version-img]][pypi-latest-version]
|
||||||
|
|
||||||
**Dynamically manage pools of asyncio tasks**
|
**Dynamically manage pools of asyncio tasks**
|
||||||
|
|
||||||
|
Full documentation available at [RtD](https://asyncio-taskpool.readthedocs.io/en/latest).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Contents
|
## Contents
|
||||||
- [Contents](#contents)
|
- [Contents](#contents)
|
||||||
- [Summary](#summary)
|
- [Summary](#summary)
|
||||||
@ -27,25 +50,16 @@ Generally speaking, a task is added to a pool by providing it with a coroutine f
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from asyncio_taskpool import SimpleTaskPool
|
from asyncio_taskpool import SimpleTaskPool
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
async def work(_foo, _bar): ...
|
async def work(_foo, _bar): ...
|
||||||
|
|
||||||
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
pool = SimpleTaskPool(work, args=('xyz', 420))
|
pool = SimpleTaskPool(work, args=('xyz', 420))
|
||||||
await pool.start(5)
|
pool.start(5)
|
||||||
...
|
...
|
||||||
pool.stop(3)
|
pool.stop(3)
|
||||||
...
|
...
|
||||||
pool.lock()
|
|
||||||
await pool.gather_and_close()
|
await pool.gather_and_close()
|
||||||
...
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
|
Since one of the main goals of `asyncio-taskpool` is to be able to start/stop tasks dynamically or "on-the-fly", _most_ of the associated methods are non-blocking _most_ of the time. A notable exception is the `gather_and_close` method for awaiting the return of all tasks in the pool. (It is essentially a glorified wrapper around the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.)
|
||||||
@ -64,8 +78,7 @@ Python Version 3.8+, tested on Linux
|
|||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
Install `asyncio-taskpool[dev]` dependencies or just manually install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`.
|
Install [`coverage`](https://coverage.readthedocs.io/en/latest/) with `pip`, then execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and save the coverage report.
|
||||||
Execute the [`./coverage.sh`](coverage.sh) shell script to run all unit tests and receive the coverage report.
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
@ -76,3 +89,13 @@ The full license texts for the [GNU GPLv3.0](COPYING) and the [GNU LGPLv3.0](COP
|
|||||||
---
|
---
|
||||||
|
|
||||||
© 2022 Daniil Fajnberg
|
© 2022 Daniil Fajnberg
|
||||||
|
|
||||||
|
[github-last-commit]: https://github.com/daniil-berg/asyncio-taskpool/commits
|
||||||
|
[github-last-commit-img]: https://img.shields.io/github/last-commit/daniil-berg/asyncio-taskpool?label=Last%20commit&logo=git&
|
||||||
|
[gist-cloc-code-img]: https://img.shields.io/endpoint?logo=python&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-code.json
|
||||||
|
[gist-cloc-comments-img]: https://img.shields.io/endpoint?logo=sharp&color=lightgrey&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/cloc-comments.json
|
||||||
|
[gist-test-coverage-img]: https://img.shields.io/endpoint?logo=pytest&color=blue&url=https://gist.githubusercontent.com/daniil-berg/3f8240a976e8781a765d9c74a583dcda/raw/test-coverage.json
|
||||||
|
[lgpl3]: https://www.gnu.org/licenses/lgpl-3.0
|
||||||
|
[lgpl3-img]: https://img.shields.io/badge/License-LGPL_v3.0-darkgreen.svg?logo=gnu
|
||||||
|
[pypi-latest-version-img]: https://img.shields.io/pypi/v/asyncio-taskpool?color=teal&logo=pypi
|
||||||
|
[pypi-latest-version]: https://pypi.org/project/asyncio-taskpool/
|
||||||
|
46
cloc.sh
Executable file
46
cloc.sh
Executable file
@ -0,0 +1,46 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
# See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
# If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
typeset option
|
||||||
|
if getopts 'bcmp' option; then
|
||||||
|
if [[ ${option} == [bcmp] ]]; then
|
||||||
|
shift
|
||||||
|
else
|
||||||
|
echo >&2 "Invalid option '$1' provided"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
typeset source=$1
|
||||||
|
if [[ -z ${source} ]]; then
|
||||||
|
echo >&2 Source file/directory missing
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
typeset blank code comment commentpercent
|
||||||
|
read blank comment code commentpercent < <( \
|
||||||
|
cloc --csv --quiet --hide-rate --include-lang Python ${source} |
|
||||||
|
awk -F, '$2 == "SUM" {printf ("%d %d %d %1.0f", $3, $4, $5, 100 * $4 / ($5 + $4)); exit}'
|
||||||
|
)
|
||||||
|
|
||||||
|
case ${option} in
|
||||||
|
b) echo ${blank} ;;
|
||||||
|
c) echo ${code} ;;
|
||||||
|
m) echo ${comment} ;;
|
||||||
|
p) echo ${commentpercent} ;;
|
||||||
|
*) echo Blank lines: ${blank}
|
||||||
|
echo Lines of comments: ${comment}
|
||||||
|
echo Lines of code: ${code}
|
||||||
|
echo Comment percentage: ${commentpercent} ;;
|
||||||
|
esac
|
26
coverage.sh
26
coverage.sh
@ -1,3 +1,25 @@
|
|||||||
#!/usr/bin/env sh
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
coverage erase && coverage run -m unittest discover && coverage report
|
# This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
# asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
# version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
# asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
# See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
# If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
coverage erase
|
||||||
|
coverage run 2> /dev/null
|
||||||
|
|
||||||
|
typeset report=$(coverage report)
|
||||||
|
typeset total=$(echo "${report}" | awk '$1 == "TOTAL" {print $NF; exit}')
|
||||||
|
|
||||||
|
if [[ ${total} == 100% ]]; then
|
||||||
|
echo ${total}
|
||||||
|
else
|
||||||
|
echo "${report}"
|
||||||
|
fi
|
||||||
|
0
docs/source/_static/.placeholder
Normal file
0
docs/source/_static/.placeholder
Normal file
@ -1,7 +0,0 @@
|
|||||||
asyncio\_taskpool.control.parser module
|
|
||||||
=======================================
|
|
||||||
|
|
||||||
.. automodule:: asyncio_taskpool.control.parser
|
|
||||||
:members:
|
|
||||||
:undoc-members:
|
|
||||||
:show-inheritance:
|
|
@ -13,6 +13,4 @@ Submodules
|
|||||||
:maxdepth: 4
|
:maxdepth: 4
|
||||||
|
|
||||||
asyncio_taskpool.control.client
|
asyncio_taskpool.control.client
|
||||||
asyncio_taskpool.control.parser
|
|
||||||
asyncio_taskpool.control.server
|
asyncio_taskpool.control.server
|
||||||
asyncio_taskpool.control.session
|
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
asyncio\_taskpool.control.session module
|
|
||||||
========================================
|
|
||||||
|
|
||||||
.. automodule:: asyncio_taskpool.control.session
|
|
||||||
:members:
|
|
||||||
:undoc-members:
|
|
||||||
:show-inheritance:
|
|
@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
|
|||||||
author = 'Daniil Fajnberg'
|
author = 'Daniil Fajnberg'
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
release = '1.0.0-beta'
|
release = '1.1.4'
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
@ -45,6 +45,7 @@ Contents
|
|||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|
||||||
pages/pool
|
pages/pool
|
||||||
|
pages/ids
|
||||||
pages/control
|
pages/control
|
||||||
api/api
|
api/api
|
||||||
|
|
||||||
|
@ -96,9 +96,9 @@ When you are dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.T
|
|||||||
|
|
||||||
.. code-block:: none
|
.. code-block:: none
|
||||||
|
|
||||||
> map mypackage.mymodule.worker ['x','y','z'] -g 3
|
> map mypackage.mymodule.worker ['x','y','z'] -n 3
|
||||||
|
|
||||||
The :code:`-g` is a shorthand for :code:`--group-size` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
|
The :code:`-n` is a shorthand for :code:`--num-concurrent` in this case. In general, all (public) pool methods will have a corresponding command in the control session.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
|
42
docs/source/pages/ids.rst
Normal file
42
docs/source/pages/ids.rst
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
.. This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
.. asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
.. asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
.. You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
.. Copyright © 2022 Daniil Fajnberg
|
||||||
|
|
||||||
|
|
||||||
|
IDs, groups & names
|
||||||
|
===================
|
||||||
|
|
||||||
|
Task IDs
|
||||||
|
--------
|
||||||
|
|
||||||
|
Every task spawned within a pool receives an ID, which is an integer greater or equal to 0 that is unique **within that task pool instance**. An internal counter is incremented whenever a new task is spawned. A task with ID :code:`n` was the :code:`(n+1)`-th task to be spawned in the pool. Task IDs can be used to cancel specific tasks using the :py:meth:`.cancel() <asyncio_taskpool.pool.BaseTaskPool.cancel>` method.
|
||||||
|
|
||||||
|
In practice, it should rarely be necessary to target *specific* tasks. When dealing with a regular :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` instance, you would typically cancel entire task groups (see below) rather than individual tasks, whereas with :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instances you would indiscriminately cancel a number of tasks using the :py:meth:`.stop() <asyncio_taskpool.pool.SimpleTaskPool.stop>` method.
|
||||||
|
|
||||||
|
The ID of a pool task also appears in the task's name, which is set upon spawning it. (See `here <https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.set_name>`_ for the associated method of the :code:`Task` class.)
|
||||||
|
|
||||||
|
Task groups
|
||||||
|
-----------
|
||||||
|
|
||||||
|
Every method of spawning new tasks in a task pool will add them to a **task group** and return the name of that group. With :py:class:`TaskPool <asyncio_taskpool.pool.TaskPool>` methods such as :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>`, the group name can be set explicitly via the :code:`group_name` parameter. By default, the name will be a string containing some meta information depending on which method is used. Passing an existing task group name in any of those methods will result in a :py:class:`InvalidGroupName <asyncio_taskpool.exceptions.InvalidGroupName>` error.
|
||||||
|
|
||||||
|
You can cancel entire task groups using the :py:meth:`.cancel_group() <asyncio_taskpool.pool.BaseTaskPool.cancel_group>` method by passing it the group name. To check which tasks belong to a group, the :py:meth:`.get_group_ids() <asyncio_taskpool.pool.BaseTaskPool.get_group_ids>` method can be used, which takes group names and returns the IDs of the tasks belonging to them.
|
||||||
|
|
||||||
|
The :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` method will create a new group as well, each time it is called, but it does not allow customizing the group name. Typically, it will not be necessary to keep track of groups in a :py:class:`SimpleTaskPool <asyncio_taskpool.pool.SimpleTaskPool>` instance.
|
||||||
|
|
||||||
|
Task groups do not impose limits on the number of tasks in them, although they can be indirectly constrained by pool size limits.
|
||||||
|
|
||||||
|
Pool names
|
||||||
|
----------
|
||||||
|
|
||||||
|
When initializing a task pool, you can provide a custom name for it, which will appear in its string representation, e.g. when using it in a :code:`print()`. A class attribute keeps track of initialized task pools and assigns each one an index (similar to IDs for pool tasks). If no name is specified when creating a new pool, its index is used in the string representation of it. Pool names can be helpful when using multiple pools and analyzing log messages.
|
@ -46,7 +46,7 @@ Let's take a look at an example. Say you have a coroutine function that takes tw
|
|||||||
async def queue_worker_function(in_queue: Queue, out_queue: Queue) -> None:
|
async def queue_worker_function(in_queue: Queue, out_queue: Queue) -> None:
|
||||||
while True:
|
while True:
|
||||||
item = await in_queue.get()
|
item = await in_queue.get()
|
||||||
... # Do some work on the item amd arrive at a result.
|
... # Do some work on the item and arrive at a result.
|
||||||
await out_queue.put(result)
|
await out_queue.put(result)
|
||||||
|
|
||||||
How would we go about concurrently executing this function, say 5 times? There are (as always) a number of ways to do this with :code:`asyncio`. If we want to use tasks and be clean about it, we can do it like this:
|
How would we go about concurrently executing this function, say 5 times? There are (as always) a number of ways to do this with :code:`asyncio`. If we want to use tasks and be clean about it, we can do it like this:
|
||||||
@ -81,13 +81,13 @@ By contrast, here is how you would do it with a task pool:
|
|||||||
|
|
||||||
...
|
...
|
||||||
pool = TaskPool()
|
pool = TaskPool()
|
||||||
group_name = await pool.apply(queue_worker_function, args=(q_in, q_out), num=5)
|
group_name = pool.apply(queue_worker_function, args=(q_in, q_out), num=5)
|
||||||
...
|
...
|
||||||
pool.cancel_group(group_name)
|
pool.cancel_group(group_name)
|
||||||
...
|
...
|
||||||
await pool.flush()
|
await pool.flush()
|
||||||
|
|
||||||
Pretty much self-explanatory, no?
|
Pretty much self-explanatory, no? (See :doc:`here <./ids>` for more information about groups/names).
|
||||||
|
|
||||||
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
|
Let's consider a slightly more involved example. Assume you have a coroutine function that takes just one argument (some data) as input, does some work with it (maybe connects to the internet in the process), and eventually writes its results to a database (which is globally defined). Here is how that might look:
|
||||||
|
|
||||||
@ -141,16 +141,17 @@ Or we could use a task pool:
|
|||||||
async def main():
|
async def main():
|
||||||
...
|
...
|
||||||
pool = TaskPool()
|
pool = TaskPool()
|
||||||
await pool.map(another_worker_function, data_iterator, group_size=5)
|
pool.map(another_worker_function, data_iterator, num_concurrent=5)
|
||||||
...
|
...
|
||||||
pool.lock()
|
|
||||||
await pool.gather_and_close()
|
await pool.gather_and_close()
|
||||||
|
|
||||||
Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction).
|
Calling the :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` method this way ensures that there will **always** -- i.e. at any given moment in time -- be exactly 5 tasks working concurrently on our data (assuming no other pool interaction).
|
||||||
|
|
||||||
|
The :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`)
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
The :py:meth:`.gather_and_close() <asyncio_taskpool.pool.BaseTaskPool.gather_and_close>` line will block until **all the data** has been consumed. (see :ref:`blocking-pool-methods`)
|
Neither :py:meth:`.apply() <asyncio_taskpool.pool.TaskPool.apply>` nor :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` return coroutines. When they are called, the task pool immediately begins scheduling new tasks to run. No :code:`await` needed.
|
||||||
|
|
||||||
It can't get any simpler than that, can it? So glad you asked...
|
It can't get any simpler than that, can it? So glad you asked...
|
||||||
|
|
||||||
@ -163,13 +164,13 @@ Let's take the :ref:`queue worker example <queue-worker-function>` from before.
|
|||||||
:caption: main.py
|
:caption: main.py
|
||||||
|
|
||||||
from asyncio_taskpool import SimpleTaskPool
|
from asyncio_taskpool import SimpleTaskPool
|
||||||
from .work import another_worker_function
|
from .work import queue_worker_function
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
...
|
...
|
||||||
pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out))
|
pool = SimpleTaskPool(queue_worker_function, args=(q_in, q_out))
|
||||||
await pool.start(5)
|
pool.start(5)
|
||||||
...
|
...
|
||||||
pool.stop_all()
|
pool.stop_all()
|
||||||
...
|
...
|
||||||
@ -229,7 +230,4 @@ The only method of a pool that one should **always** assume to be blocking is :p
|
|||||||
|
|
||||||
One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations.
|
One method to be aware of is :py:meth:`.flush() <asyncio_taskpool.pool.BaseTaskPool.flush>`. Since it will await only those tasks that the pool considers **ended** or **cancelled**, the blocking can only come from any callbacks that were provided for either of those situations.
|
||||||
|
|
||||||
In general, the act of adding tasks to a pool is non-blocking, no matter which particular methods are used. The only notable exception is when a limit on the pool size has been set and there is "not enough room" to add a task. In this case, both :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>` and :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` will block until the desired number of new tasks found room in the pool (either because other tasks have ended or because the pool size was increased).
|
All methods that add tasks to a pool, i.e. :py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants), :py:meth:`TaskPool.apply() <asyncio_taskpool.pool.TaskPool.apply>` and :py:meth:`SimpleTaskPool.start() <asyncio_taskpool.pool.SimpleTaskPool.start>`, are non-blocking by design. They all make use of "meta tasks" under the hood and return immediately. It is important however, to realize that just because they return, does not mean that any actual tasks have been spawned. For example, if a pool size limit was set and there was "no more room" in the pool when :py:meth:`.map() <asyncio_taskpool.pool.TaskPool.map>` was called, there is **no guarantee** that even a single task has started, when it returns.
|
||||||
|
|
||||||
:py:meth:`TaskPool.map() <asyncio_taskpool.pool.TaskPool.map>` (and its variants) will **never** block. Since it makes use of "meta-tasks" under the hood, it will always return immediately. However, if the pool was full when it was called, there is **no guarantee** that even a single task has started, when the method returns.
|
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 1.0.0-beta
|
version = 1.1.4
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
@ -11,7 +11,7 @@ url = https://git.fajnberg.de/daniil/asyncio-taskpool
|
|||||||
project_urls =
|
project_urls =
|
||||||
Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues
|
Bug Tracker = https://github.com/daniil-berg/asyncio-taskpool/issues
|
||||||
classifiers =
|
classifiers =
|
||||||
Development Status :: 4 - Beta
|
Development Status :: 5 - Production/Stable
|
||||||
Programming Language :: Python :: 3
|
Programming Language :: Python :: 3
|
||||||
Operating System :: OS Independent
|
Operating System :: OS Independent
|
||||||
License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
|
License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
|
||||||
|
@ -35,7 +35,7 @@ __all__ = []
|
|||||||
|
|
||||||
CLIENT_CLASS = 'client_class'
|
CLIENT_CLASS = 'client_class'
|
||||||
UNIX, TCP = 'unix', 'tcp'
|
UNIX, TCP = 'unix', 'tcp'
|
||||||
SOCKET_PATH = 'path'
|
SOCKET_PATH = 'socket_path'
|
||||||
HOST, PORT = 'host', 'port'
|
HOST, PORT = 'host', 'port'
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,6 +85,7 @@ class ControlClient(ABC):
|
|||||||
"""
|
"""
|
||||||
self._connected = True
|
self._connected = True
|
||||||
writer.write(json.dumps(self._client_info()).encode())
|
writer.write(json.dumps(self._client_info()).encode())
|
||||||
|
writer.write(b'\n')
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
print("Connected to", (await reader.read(SESSION_MSG_BYTES)).decode())
|
||||||
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
print("Type '-h' to get help and usage instructions for all available commands.\n")
|
||||||
@ -97,21 +98,22 @@ class ControlClient(ABC):
|
|||||||
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
writer: The `asyncio.StreamWriter` returned by the `_open_connection()` method
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`None`, if either `Ctrl+C` was hit, or the user wants the client to disconnect;
|
`None`, if either `Ctrl+C` was hit, an empty or whitespace-only string was entered, or the user wants the
|
||||||
otherwise, the user's input, stripped of leading and trailing spaces and converted to lowercase.
|
client to disconnect; otherwise, returns the user's input, stripped of leading and trailing spaces and
|
||||||
|
converted to lowercase.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
msg = input("> ").strip().lower()
|
cmd = input("> ").strip().lower()
|
||||||
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
|
except EOFError: # Ctrl+D shall be equivalent to the :const:`CLIENT_EXIT` command.
|
||||||
msg = CLIENT_EXIT
|
cmd = CLIENT_EXIT
|
||||||
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
except KeyboardInterrupt: # Ctrl+C shall simply reset to the input prompt.
|
||||||
print()
|
print()
|
||||||
return
|
return
|
||||||
if msg == CLIENT_EXIT:
|
if cmd == CLIENT_EXIT:
|
||||||
writer.close()
|
writer.close()
|
||||||
self._connected = False
|
self._connected = False
|
||||||
return
|
return
|
||||||
return msg
|
return cmd or None # will be None if `cmd` is an empty string
|
||||||
|
|
||||||
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
async def _interact(self, reader: StreamReader, writer: StreamWriter) -> None:
|
||||||
"""
|
"""
|
||||||
@ -130,6 +132,7 @@ class ControlClient(ABC):
|
|||||||
try:
|
try:
|
||||||
# Send the command to the server.
|
# Send the command to the server.
|
||||||
writer.write(cmd.encode())
|
writer.write(cmd.encode())
|
||||||
|
writer.write(b'\n')
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
@ -17,18 +17,21 @@ If not, see <https://www.gnu.org/licenses/>."""
|
|||||||
__doc__ = """
|
__doc__ = """
|
||||||
Definition of the :class:`ControlParser` used in a
|
Definition of the :class:`ControlParser` used in a
|
||||||
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
|
:class:`ControlSession <asyncio_taskpool.control.session.ControlSession>`.
|
||||||
|
|
||||||
|
It should not be considered part of the public API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, SUPPRESS
|
import logging
|
||||||
|
from argparse import Action, ArgumentParser, ArgumentDefaultsHelpFormatter, HelpFormatter, ArgumentTypeError, SUPPRESS
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
from asyncio.streams import StreamWriter
|
|
||||||
from inspect import Parameter, getmembers, isfunction, signature
|
from inspect import Parameter, getmembers, isfunction, signature
|
||||||
|
from io import StringIO
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
|
from typing import Any, Callable, Container, Dict, Iterable, Set, Type, TypeVar
|
||||||
|
|
||||||
from ..exceptions import HelpRequested, ParserError
|
from ..exceptions import HelpRequested, ParserError
|
||||||
from ..internals.constants import CLIENT_INFO, CMD, STREAM_WRITER
|
from ..internals.constants import CLIENT_INFO, CMD
|
||||||
from ..internals.helpers import get_first_doc_line, resolve_dotted_path
|
from ..internals.helpers import get_first_doc_line, resolve_dotted_path
|
||||||
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
||||||
|
|
||||||
@ -36,6 +39,9 @@ from ..internals.types import ArgsT, CancelCB, CoroutineFunc, EndCB, KwArgsT
|
|||||||
__all__ = ['ControlParser']
|
__all__ = ['ControlParser']
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
|
FmtCls = TypeVar('FmtCls', bound=Type[HelpFormatter])
|
||||||
ParsersDict = Dict[str, 'ControlParser']
|
ParsersDict = Dict[str, 'ControlParser']
|
||||||
|
|
||||||
@ -48,8 +54,8 @@ class ControlParser(ArgumentParser):
|
|||||||
"""
|
"""
|
||||||
Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
|
Subclass of the standard :code:`argparse.ArgumentParser` for pool control.
|
||||||
|
|
||||||
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a `StreamWriter`
|
Such a parser is not supposed to ever print to stdout/stderr, but instead direct all messages to a file-like
|
||||||
instance passed to it during initialization.
|
`StringIO` instance passed to it during initialization.
|
||||||
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
|
Furthermore, it requires defining the width of the terminal, to adjust help formatting to the terminal size of a
|
||||||
connected client.
|
connected client.
|
||||||
Finally, it offers some convenience methods and makes use of custom exceptions.
|
Finally, it offers some convenience methods and makes use of custom exceptions.
|
||||||
@ -83,25 +89,23 @@ class ControlParser(ArgumentParser):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
return ClientHelpFormatter
|
return ClientHelpFormatter
|
||||||
|
|
||||||
def __init__(self, stream_writer: StreamWriter, terminal_width: int = None, **kwargs) -> None:
|
def __init__(self, stream: StringIO, terminal_width: int = None, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Sets some internal attributes in addition to the base class.
|
Sets some internal attributes in addition to the base class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_writer:
|
stream:
|
||||||
The instance of the :class:`asyncio.StreamWriter` to use for message output.
|
A file-like I/O object to use for message output.
|
||||||
terminal_width (optional):
|
terminal_width (optional):
|
||||||
The terminal width to use for all message formatting. By default the :code:`columns` attribute from
|
The terminal width to use for all message formatting. By default the :code:`columns` attribute from
|
||||||
:func:`shutil.get_terminal_size` is taken.
|
:func:`shutil.get_terminal_size` is taken.
|
||||||
**kwargs(optional):
|
**kwargs(optional):
|
||||||
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
|
Passed to the parent class constructor. The exception is the `formatter_class` parameter: Even if a
|
||||||
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
|
class is specified, it will always be subclassed in the :meth:`help_formatter_factory`.
|
||||||
Also, by default, `exit_on_error` is set to `False` (as opposed to how the parent class handles it).
|
|
||||||
"""
|
"""
|
||||||
self._stream_writer: StreamWriter = stream_writer
|
self._stream: StringIO = stream
|
||||||
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
|
self._terminal_width: int = terminal_width if terminal_width is not None else get_terminal_size().columns
|
||||||
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
|
kwargs['formatter_class'] = self.help_formatter_factory(self._terminal_width, kwargs.get('formatter_class'))
|
||||||
kwargs.setdefault('exit_on_error', False)
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._flags: Set[str] = set()
|
self._flags: Set[str] = set()
|
||||||
self._commands = None
|
self._commands = None
|
||||||
@ -190,7 +194,7 @@ class ControlParser(ArgumentParser):
|
|||||||
Dictionary mapping class member names to the (sub-)parsers created from them.
|
Dictionary mapping class member names to the (sub-)parsers created from them.
|
||||||
"""
|
"""
|
||||||
parsers: ParsersDict = {}
|
parsers: ParsersDict = {}
|
||||||
common_kwargs = {STREAM_WRITER: self._stream_writer, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
|
common_kwargs = {'stream': self._stream, CLIENT_INFO.TERMINAL_WIDTH: self._terminal_width}
|
||||||
for name, member in getmembers(cls):
|
for name, member in getmembers(cls):
|
||||||
if name in omit_members or (name.startswith('_') and public_only):
|
if name in omit_members or (name.startswith('_') and public_only):
|
||||||
continue
|
continue
|
||||||
@ -210,9 +214,9 @@ class ControlParser(ArgumentParser):
|
|||||||
return self._commands
|
return self._commands
|
||||||
|
|
||||||
def _print_message(self, message: str, *args, **kwargs) -> None:
|
def _print_message(self, message: str, *args, **kwargs) -> None:
|
||||||
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream writer."""
|
"""This is overridden to ensure that no messages are sent to stdout/stderr, but always to the stream buffer."""
|
||||||
if message:
|
if message:
|
||||||
self._stream_writer.write(message.encode())
|
self._stream.write(message)
|
||||||
|
|
||||||
def exit(self, status: int = 0, message: str = None) -> None:
|
def exit(self, status: int = 0, message: str = None) -> None:
|
||||||
"""This is overridden to prevent system exit to be invoked."""
|
"""This is overridden to prevent system exit to be invoked."""
|
||||||
@ -300,8 +304,21 @@ def _get_arg_type_wrapper(cls: Type) -> Callable[[Any], Any]:
|
|||||||
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
|
Returns a wrapper for the constructor of `cls` to avoid a ValueError being raised on suppressed arguments.
|
||||||
|
|
||||||
See: https://bugs.python.org/issue36078
|
See: https://bugs.python.org/issue36078
|
||||||
|
|
||||||
|
In addition, the type conversion wrapper catches exceptions not handled properly by the parser, logs them, and
|
||||||
|
turns them into `ArgumentTypeError` exceptions the parser can propagate to the client.
|
||||||
"""
|
"""
|
||||||
def wrapper(arg: Any) -> Any: return arg if arg is SUPPRESS else cls(arg)
|
def wrapper(arg: Any) -> Any:
|
||||||
|
if arg is SUPPRESS:
|
||||||
|
return arg
|
||||||
|
try:
|
||||||
|
return cls(arg)
|
||||||
|
except (ArgumentTypeError, TypeError, ValueError):
|
||||||
|
raise # handled properly by the parser and propagated to the client anyway
|
||||||
|
except Exception as e:
|
||||||
|
text = f"{e.__class__.__name__} occurred in parser trying to convert type: {cls.__name__}({repr(arg)})"
|
||||||
|
log.exception(text)
|
||||||
|
raise ArgumentTypeError(text) # propagate to the client
|
||||||
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
|
# Copy the name of the class to maintain useful help messages when incorrect arguments are passed.
|
||||||
wrapper.__name__ = cls.__name__
|
wrapper.__name__ = cls.__name__
|
||||||
return wrapper
|
return wrapper
|
||||||
|
@ -31,6 +31,7 @@ from typing import Optional, Union
|
|||||||
from .client import ControlClient, TCPControlClient, UnixControlClient
|
from .client import ControlClient, TCPControlClient, UnixControlClient
|
||||||
from .session import ControlSession
|
from .session import ControlSession
|
||||||
from ..pool import AnyTaskPoolT
|
from ..pool import AnyTaskPoolT
|
||||||
|
from ..internals.helpers import classmethod
|
||||||
from ..internals.types import ConnectedCallbackT, PathT
|
from ..internals.types import ConnectedCallbackT, PathT
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +16,8 @@ If not, see <https://www.gnu.org/licenses/>."""
|
|||||||
|
|
||||||
__doc__ = """
|
__doc__ = """
|
||||||
Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
|
Definition of the :class:`ControlSession` used by a :class:`ControlServer`.
|
||||||
|
|
||||||
|
It should not be considered part of the public API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -24,12 +26,13 @@ import json
|
|||||||
from argparse import ArgumentError
|
from argparse import ArgumentError
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
from inspect import isfunction, signature
|
from inspect import isfunction, signature
|
||||||
|
from io import StringIO
|
||||||
from typing import Callable, Optional, Union, TYPE_CHECKING
|
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from .parser import ControlParser
|
from .parser import ControlParser
|
||||||
from ..exceptions import CommandError, HelpRequested, ParserError
|
from ..exceptions import CommandError, HelpRequested, ParserError
|
||||||
from ..pool import TaskPool, SimpleTaskPool
|
from ..pool import TaskPool, SimpleTaskPool
|
||||||
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK, SESSION_MSG_BYTES, STREAM_WRITER
|
from ..internals.constants import CLIENT_INFO, CMD, CMD_OK
|
||||||
from ..internals.helpers import return_or_exception
|
from ..internals.helpers import return_or_exception
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -72,6 +75,7 @@ class ControlSession:
|
|||||||
self._reader: StreamReader = reader
|
self._reader: StreamReader = reader
|
||||||
self._writer: StreamWriter = writer
|
self._writer: StreamWriter = writer
|
||||||
self._parser: Optional[ControlParser] = None
|
self._parser: Optional[ControlParser] = None
|
||||||
|
self._response_buffer: StringIO = StringIO()
|
||||||
|
|
||||||
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
|
async def _exec_method_and_respond(self, method: Callable, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
@ -99,7 +103,7 @@ class ControlSession:
|
|||||||
elif param.kind == param.VAR_POSITIONAL:
|
elif param.kind == param.VAR_POSITIONAL:
|
||||||
var_pos = kwargs.pop(param.name)
|
var_pos = kwargs.pop(param.name)
|
||||||
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
|
output = await return_or_exception(method, *normal_pos, *var_pos, **kwargs)
|
||||||
self._writer.write(CMD_OK if output is None else str(output).encode())
|
self._response_buffer.write(CMD_OK.decode() if output is None else str(output))
|
||||||
|
|
||||||
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
|
async def _exec_property_and_respond(self, prop: property, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
@ -118,10 +122,10 @@ class ControlSession:
|
|||||||
if kwargs:
|
if kwargs:
|
||||||
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
|
log.debug("%s sets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fset.__name__)
|
||||||
await return_or_exception(prop.fset, self._pool, **kwargs)
|
await return_or_exception(prop.fset, self._pool, **kwargs)
|
||||||
self._writer.write(CMD_OK)
|
self._response_buffer.write(CMD_OK.decode())
|
||||||
else:
|
else:
|
||||||
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
|
log.debug("%s gets %s.%s", self._client_class_name, self._pool.__class__.__name__, prop.fget.__name__)
|
||||||
self._writer.write(str(await return_or_exception(prop.fget, self._pool)).encode())
|
self._response_buffer.write(str(await return_or_exception(prop.fget, self._pool)))
|
||||||
|
|
||||||
async def client_handshake(self) -> None:
|
async def client_handshake(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -130,10 +134,11 @@ class ControlSession:
|
|||||||
Client info is retrieved, server info is sent back, and the
|
Client info is retrieved, server info is sent back, and the
|
||||||
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
|
:class:`ControlParser <asyncio_taskpool.control.parser.ControlParser>` is set up.
|
||||||
"""
|
"""
|
||||||
client_info = json.loads((await self._reader.read(SESSION_MSG_BYTES)).decode().strip())
|
msg = (await self._reader.readline()).decode().strip()
|
||||||
|
client_info = json.loads(msg)
|
||||||
log.debug("%s connected", self._client_class_name)
|
log.debug("%s connected", self._client_class_name)
|
||||||
parser_kwargs = {
|
parser_kwargs = {
|
||||||
STREAM_WRITER: self._writer,
|
'stream': self._response_buffer,
|
||||||
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
|
CLIENT_INFO.TERMINAL_WIDTH: client_info[CLIENT_INFO.TERMINAL_WIDTH],
|
||||||
'prog': '',
|
'prog': '',
|
||||||
'usage': f'[-h] [{CMD}] ...'
|
'usage': f'[-h] [{CMD}] ...'
|
||||||
@ -142,7 +147,7 @@ class ControlSession:
|
|||||||
self._parser.add_subparsers(title="Commands",
|
self._parser.add_subparsers(title="Commands",
|
||||||
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
|
metavar="(A command followed by '-h' or '--help' will show command-specific help.)")
|
||||||
self._parser.add_class_commands(self._pool.__class__)
|
self._parser.add_class_commands(self._pool.__class__)
|
||||||
self._writer.write(str(self._pool).encode())
|
self._writer.write(str(self._pool).encode() + b'\n')
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
|
||||||
async def _parse_command(self, msg: str) -> None:
|
async def _parse_command(self, msg: str) -> None:
|
||||||
@ -160,7 +165,7 @@ class ControlSession:
|
|||||||
kwargs = vars(self._parser.parse_args(msg.split(' ')))
|
kwargs = vars(self._parser.parse_args(msg.split(' ')))
|
||||||
except ArgumentError as e:
|
except ArgumentError as e:
|
||||||
log.debug("%s got an ArgumentError", self._client_class_name)
|
log.debug("%s got an ArgumentError", self._client_class_name)
|
||||||
self._writer.write(str(e).encode())
|
self._response_buffer.write(str(e))
|
||||||
return
|
return
|
||||||
except (HelpRequested, ParserError):
|
except (HelpRequested, ParserError):
|
||||||
log.debug("%s received usage help", self._client_class_name)
|
log.debug("%s received usage help", self._client_class_name)
|
||||||
@ -171,7 +176,7 @@ class ControlSession:
|
|||||||
elif isinstance(command, property):
|
elif isinstance(command, property):
|
||||||
await self._exec_property_and_respond(command, **kwargs)
|
await self._exec_property_and_respond(command, **kwargs)
|
||||||
else:
|
else:
|
||||||
self._writer.write(str(CommandError(f"Unknown command object: {command}")).encode())
|
self._response_buffer.write(str(CommandError(f"Unknown command object: {command}")))
|
||||||
|
|
||||||
async def listen(self) -> None:
|
async def listen(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -183,9 +188,13 @@ class ControlSession:
|
|||||||
It will obviously block indefinitely.
|
It will obviously block indefinitely.
|
||||||
"""
|
"""
|
||||||
while self._control_server.is_serving():
|
while self._control_server.is_serving():
|
||||||
msg = (await self._reader.read(SESSION_MSG_BYTES)).decode().strip()
|
msg = (await self._reader.readline()).decode().strip()
|
||||||
if not msg:
|
if not msg:
|
||||||
log.debug("%s disconnected", self._client_class_name)
|
log.debug("%s disconnected", self._client_class_name)
|
||||||
break
|
break
|
||||||
await self._parse_command(msg)
|
await self._parse_command(msg)
|
||||||
|
response = self._response_buffer.getvalue() + "\n"
|
||||||
|
self._response_buffer.seek(0)
|
||||||
|
self._response_buffer.truncate()
|
||||||
|
self._writer.write(response.encode())
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
|
@ -51,10 +51,6 @@ class InvalidGroupName(PoolException):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PoolStillUnlocked(PoolException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NotCoroutine(PoolException):
|
class NotCoroutine(PoolException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -21,15 +21,17 @@ This module should **not** be considered part of the public API.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = 'asyncio_taskpool'
|
PACKAGE_NAME = 'asyncio_taskpool'
|
||||||
|
|
||||||
|
PYTHON_BEFORE_39 = sys.version_info[:2] < (3, 9)
|
||||||
|
|
||||||
DEFAULT_TASK_GROUP = 'default'
|
DEFAULT_TASK_GROUP = 'default'
|
||||||
|
|
||||||
DATETIME_FORMAT = '%Y-%m-%d_%H-%M-%S'
|
|
||||||
|
|
||||||
SESSION_MSG_BYTES = 1024 * 100
|
SESSION_MSG_BYTES = 1024 * 100
|
||||||
|
|
||||||
STREAM_WRITER = 'stream_writer'
|
|
||||||
CMD = 'command'
|
CMD = 'command'
|
||||||
CMD_OK = b"ok"
|
CMD_OK = b"ok"
|
||||||
|
|
||||||
|
@ -19,12 +19,13 @@ Miscellaneous helper functions. None of these should be considered part of the p
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import builtins
|
||||||
from asyncio.coroutines import iscoroutinefunction
|
from asyncio.coroutines import iscoroutinefunction
|
||||||
from asyncio.queues import Queue
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from inspect import getdoc
|
from inspect import getdoc
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Callable, Optional, Type, Union
|
||||||
|
|
||||||
|
from .constants import PYTHON_BEFORE_39
|
||||||
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
from .types import T, AnyCallableT, ArgsT, KwArgsT
|
||||||
|
|
||||||
|
|
||||||
@ -86,11 +87,6 @@ def star_function(function: AnyCallableT, arg: Any, arg_stars: int = 0) -> T:
|
|||||||
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
|
raise ValueError(f"Invalid argument arg_stars={arg_stars}; must be 0, 1, or 2.")
|
||||||
|
|
||||||
|
|
||||||
async def join_queue(q: Queue) -> None:
|
|
||||||
"""Wrapper function around the join method of an `asyncio.Queue` instance."""
|
|
||||||
await q.join()
|
|
||||||
|
|
||||||
|
|
||||||
def get_first_doc_line(obj: object) -> str:
|
def get_first_doc_line(obj: object) -> str:
|
||||||
"""Takes an object and returns the first (non-empty) line of its docstring."""
|
"""Takes an object and returns the first (non-empty) line of its docstring."""
|
||||||
return getdoc(obj).strip().split("\n", 1)[0].strip()
|
return getdoc(obj).strip().split("\n", 1)[0].strip()
|
||||||
@ -137,3 +133,25 @@ def resolve_dotted_path(dotted_path: str) -> object:
|
|||||||
import_module(module_name)
|
import_module(module_name)
|
||||||
found = getattr(found, name)
|
found = getattr(found, name)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
class ClassMethodWorkaround:
|
||||||
|
"""Dirty workaround to make the `@classmethod` decorator work with properties."""
|
||||||
|
|
||||||
|
def __init__(self, method_or_property: Union[Callable, property]) -> None:
|
||||||
|
if isinstance(method_or_property, property):
|
||||||
|
self._getter = method_or_property.fget
|
||||||
|
else:
|
||||||
|
self._getter = method_or_property
|
||||||
|
|
||||||
|
def __get__(self, obj: Union[T, None], cls: Union[Type[T], None]) -> Any:
|
||||||
|
if obj is None:
|
||||||
|
return self._getter(cls)
|
||||||
|
return self._getter(obj)
|
||||||
|
|
||||||
|
|
||||||
|
# Starting with Python 3.9, this is thankfully no longer necessary.
|
||||||
|
if PYTHON_BEFORE_39:
|
||||||
|
classmethod = ClassMethodWorkaround
|
||||||
|
else:
|
||||||
|
classmethod = builtins.classmethod
|
||||||
|
@ -23,7 +23,7 @@ This module should **not** be considered part of the public API.
|
|||||||
|
|
||||||
from asyncio.streams import StreamReader, StreamWriter
|
from asyncio.streams import StreamReader, StreamWriter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, TypeVar, Union
|
from typing import Any, Awaitable, Callable, Coroutine, Iterable, Mapping, Tuple, TypeVar, Union
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@ -31,8 +31,8 @@ T = TypeVar('T')
|
|||||||
ArgsT = Iterable[Any]
|
ArgsT = Iterable[Any]
|
||||||
KwArgsT = Mapping[str, Any]
|
KwArgsT = Mapping[str, Any]
|
||||||
|
|
||||||
AnyCallableT = Callable[[...], Union[T, Awaitable[T]]]
|
AnyCallableT = Callable[..., Union[T, Awaitable[T]]]
|
||||||
CoroutineFunc = Callable[[...], Awaitable[Any]]
|
CoroutineFunc = Callable[..., Coroutine]
|
||||||
|
|
||||||
EndCB = Callable
|
EndCB = Callable
|
||||||
CancelCB = Callable
|
CancelCB = Callable
|
||||||
|
@ -28,21 +28,19 @@ For further details about the classes check their respective documentation.
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Event, Semaphore
|
||||||
from asyncio.queues import QueueEmpty
|
|
||||||
from asyncio.tasks import Task, create_task, gather
|
from asyncio.tasks import Task, create_task, gather
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from datetime import datetime
|
|
||||||
from math import inf
|
from math import inf
|
||||||
from typing import Any, Awaitable, Dict, Iterable, Iterator, List, Set, Union
|
from typing import Any, Awaitable, Dict, Iterable, List, Set, Union
|
||||||
|
|
||||||
from . import exceptions
|
from . import exceptions
|
||||||
from .queue_context import Queue
|
from .internals.constants import DEFAULT_TASK_GROUP, PYTHON_BEFORE_39
|
||||||
from .internals.constants import DEFAULT_TASK_GROUP, DATETIME_FORMAT
|
|
||||||
from .internals.group_register import TaskGroupRegister
|
from .internals.group_register import TaskGroupRegister
|
||||||
from .internals.helpers import execute_optional, star_function, join_queue
|
from .internals.helpers import execute_optional, star_function
|
||||||
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
from .internals.types import ArgsT, KwArgsT, CoroutineFunc, EndCB, CancelCB
|
||||||
|
|
||||||
|
|
||||||
@ -74,7 +72,7 @@ class BaseTaskPool:
|
|||||||
|
|
||||||
# Initialize flags; immutably set the name.
|
# Initialize flags; immutably set the name.
|
||||||
self._locked: bool = False
|
self._locked: bool = False
|
||||||
self._closed: bool = False
|
self._closed: Event = Event()
|
||||||
self._name: str = name
|
self._name: str = name
|
||||||
|
|
||||||
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
||||||
@ -82,11 +80,14 @@ class BaseTaskPool:
|
|||||||
self._tasks_cancelled: Dict[int, Task] = {}
|
self._tasks_cancelled: Dict[int, Task] = {}
|
||||||
self._tasks_ended: Dict[int, Task] = {}
|
self._tasks_ended: Dict[int, Task] = {}
|
||||||
|
|
||||||
# These next three attributes act as synchronisation primitives necessary for managing the pool.
|
# These next two attributes act as synchronisation primitives necessary for managing the pool.
|
||||||
self._before_gathering: List[Awaitable] = []
|
|
||||||
self._enough_room: Semaphore = Semaphore()
|
self._enough_room: Semaphore = Semaphore()
|
||||||
self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
|
self._task_groups: Dict[str, TaskGroupRegister[int]] = {}
|
||||||
|
|
||||||
|
# Mapping task group names to sets of meta tasks, and a bucket for cancelled meta tasks.
|
||||||
|
self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
|
||||||
|
self._meta_tasks_cancelled: Set[Task] = set()
|
||||||
|
|
||||||
# Finish with method/functions calls that add the pool to the internal list of pools, set its initial size,
|
# Finish with method/functions calls that add the pool to the internal list of pools, set its initial size,
|
||||||
# and issue a log message.
|
# and issue a log message.
|
||||||
self._idx: int = self._add_pool(self)
|
self._idx: int = self._add_pool(self)
|
||||||
@ -220,7 +221,7 @@ class BaseTaskPool:
|
|||||||
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
||||||
if function and not iscoroutinefunction(function):
|
if function and not iscoroutinefunction(function):
|
||||||
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
||||||
if self._closed:
|
if self._closed.is_set():
|
||||||
raise exceptions.PoolIsClosed("You must use another pool")
|
raise exceptions.PoolIsClosed("You must use another pool")
|
||||||
if self._locked and not ignore_lock:
|
if self._locked and not ignore_lock:
|
||||||
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
||||||
@ -326,6 +327,8 @@ class BaseTaskPool:
|
|||||||
"""
|
"""
|
||||||
self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
|
self._check_start(awaitable=awaitable, ignore_lock=ignore_lock)
|
||||||
await self._enough_room.acquire()
|
await self._enough_room.acquire()
|
||||||
|
# TODO: Make sure that cancellation (group or pool) interrupts this method after context switching!
|
||||||
|
# Possibly make use of the task group register for that.
|
||||||
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
|
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
|
||||||
async with group_reg:
|
async with group_reg:
|
||||||
task_id = self._num_started
|
task_id = self._num_started
|
||||||
@ -358,6 +361,23 @@ class BaseTaskPool:
|
|||||||
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
raise exceptions.AlreadyEnded(f"{self._task_name(task_id)} has finished running")
|
||||||
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
raise exceptions.InvalidTaskID(f"No task with ID {task_id} found in {self}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_cancel_kw(msg: Union[str, None]) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Returns a dictionary to unpack in a `Task.cancel()` method.
|
||||||
|
|
||||||
|
This method exists to ensure proper compatibility with older Python versions.
|
||||||
|
If `msg` is `None`, an empty dictionary is returned.
|
||||||
|
If `PYTHON_BEFORE_39` is `True` a warning is issued before returning an empty dictionary.
|
||||||
|
Otherwise the keyword dictionary contains the `msg` parameter.
|
||||||
|
"""
|
||||||
|
if msg is None:
|
||||||
|
return {}
|
||||||
|
if PYTHON_BEFORE_39:
|
||||||
|
warnings.warn("Parameter `msg` is not available with Python versions before 3.9 and will be ignored.")
|
||||||
|
return {}
|
||||||
|
return {'msg': msg}
|
||||||
|
|
||||||
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
def cancel(self, *task_ids: int, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels the tasks with the specified IDs.
|
Cancels the tasks with the specified IDs.
|
||||||
@ -376,132 +396,9 @@ class BaseTaskPool:
|
|||||||
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
`InvalidTaskID`: One of the `task_ids` is not known to the pool.
|
||||||
"""
|
"""
|
||||||
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
tasks = [self._get_running_task(task_id) for task_id in task_ids]
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel(msg=msg)
|
task.cancel(**kw)
|
||||||
|
|
||||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
|
||||||
"""
|
|
||||||
Removes all tasks from the specified group and cancels them.
|
|
||||||
|
|
||||||
Does nothing to tasks, that are no longer running.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group_name: The name of the group of tasks that shall be cancelled.
|
|
||||||
group_reg: The task group register object containing the task IDs.
|
|
||||||
msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
|
||||||
"""
|
|
||||||
while group_reg:
|
|
||||||
try:
|
|
||||||
self._tasks_running[group_reg.pop()].cancel(msg=msg)
|
|
||||||
except KeyError:
|
|
||||||
continue
|
|
||||||
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
|
||||||
|
|
||||||
async def cancel_group(self, group_name: str, msg: str = None) -> None:
|
|
||||||
"""
|
|
||||||
Cancels an entire group of tasks.
|
|
||||||
|
|
||||||
The task group is subsequently forgotten by the pool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group_name: The name of the group of tasks that shall be cancelled.
|
|
||||||
msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
`InvalidGroupName`: if no task group named `group_name` exists in the pool.
|
|
||||||
"""
|
|
||||||
log.debug("%s cancelling tasks in group %s", str(self), group_name)
|
|
||||||
try:
|
|
||||||
group_reg = self._task_groups.pop(group_name)
|
|
||||||
except KeyError:
|
|
||||||
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
|
||||||
async with group_reg:
|
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
|
||||||
log.debug("%s forgot task group %s", str(self), group_name)
|
|
||||||
|
|
||||||
async def cancel_all(self, msg: str = None) -> None:
|
|
||||||
"""
|
|
||||||
Cancels all tasks still running within the pool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
|
||||||
"""
|
|
||||||
log.warning("%s cancelling all tasks!", str(self))
|
|
||||||
while self._task_groups:
|
|
||||||
group_name, group_reg = self._task_groups.popitem()
|
|
||||||
async with group_reg:
|
|
||||||
self._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
|
||||||
|
|
||||||
async def flush(self, return_exceptions: bool = False):
|
|
||||||
"""
|
|
||||||
Gathers (i.e. awaits) all ended/cancelled tasks in the pool.
|
|
||||||
|
|
||||||
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
|
|
||||||
`Task` objects.
|
|
||||||
|
|
||||||
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
|
|
||||||
registered for the tasks block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
return_exceptions (optional): Passed directly into `gather`.
|
|
||||||
"""
|
|
||||||
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), return_exceptions=return_exceptions)
|
|
||||||
self._tasks_ended.clear()
|
|
||||||
self._tasks_cancelled.clear()
|
|
||||||
|
|
||||||
async def gather_and_close(self, return_exceptions: bool = False):
|
|
||||||
"""
|
|
||||||
Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
|
|
||||||
|
|
||||||
After this method returns, no more tasks can be started in the pool.
|
|
||||||
|
|
||||||
:meth:`lock` must have been called prior to this.
|
|
||||||
|
|
||||||
This method may block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of the
|
|
||||||
callbacks registered for a task blocks for whatever reason.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
return_exceptions (optional): Passed directly into `gather`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
`PoolStillUnlocked`: The pool has not been locked yet.
|
|
||||||
"""
|
|
||||||
if not self._locked:
|
|
||||||
raise exceptions.PoolStillUnlocked("Pool must be locked, before tasks can be gathered")
|
|
||||||
await gather(*self._before_gathering)
|
|
||||||
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
|
|
||||||
return_exceptions=return_exceptions)
|
|
||||||
self._tasks_ended.clear()
|
|
||||||
self._tasks_cancelled.clear()
|
|
||||||
self._tasks_running.clear()
|
|
||||||
self._before_gathering.clear()
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
|
|
||||||
class TaskPool(BaseTaskPool):
|
|
||||||
"""
|
|
||||||
General purpose task pool class.
|
|
||||||
|
|
||||||
Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
|
|
||||||
|
|
||||||
A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
|
|
||||||
Tasks in the pool can all belong to the same coroutine function,
|
|
||||||
but they can also come from any number of different and unrelated coroutine functions.
|
|
||||||
|
|
||||||
As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.)
|
|
||||||
Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment.
|
|
||||||
|
|
||||||
Adding tasks blocks **only if** the pool is full at that moment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_QUEUE_END_SENTINEL = object()
|
|
||||||
|
|
||||||
def __init__(self, pool_size: int = inf, name: str = None) -> None:
|
|
||||||
super().__init__(pool_size=pool_size, name=name)
|
|
||||||
# In addition to all the attributes of the base class, we need a dictionary mapping task group names to sets of
|
|
||||||
# meta tasks that are/were running in the context of that group, and a bucked for cancelled meta tasks.
|
|
||||||
self._group_meta_tasks_running: Dict[str, Set[Task]] = {}
|
|
||||||
self._meta_tasks_cancelled: Set[Task] = set()
|
|
||||||
|
|
||||||
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
def _cancel_group_meta_tasks(self, group_name: str) -> None:
|
||||||
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
"""Cancels and forgets all meta tasks associated with the task group named `group_name`."""
|
||||||
@ -514,44 +411,67 @@ class TaskPool(BaseTaskPool):
|
|||||||
self._meta_tasks_cancelled.update(meta_tasks)
|
self._meta_tasks_cancelled.update(meta_tasks)
|
||||||
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
log.debug("%s cancelled and forgot meta tasks from group %s", str(self), group_name)
|
||||||
|
|
||||||
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, msg: str = None) -> None:
|
def _cancel_and_remove_all_from_group(self, group_name: str, group_reg: TaskGroupRegister, **cancel_kw) -> None:
|
||||||
"""See base class."""
|
"""
|
||||||
self._cancel_group_meta_tasks(group_name)
|
Removes all tasks from the specified group and cancels them.
|
||||||
super()._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg)
|
|
||||||
|
|
||||||
async def cancel_group(self, group_name: str, msg: str = None) -> None:
|
Does nothing to tasks, that are no longer running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_name: The name of the group of tasks that shall be cancelled.
|
||||||
|
group_reg: The task group register object containing the task IDs.
|
||||||
|
msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
||||||
|
"""
|
||||||
|
self._cancel_group_meta_tasks(group_name)
|
||||||
|
while group_reg:
|
||||||
|
try:
|
||||||
|
self._tasks_running[group_reg.pop()].cancel(**cancel_kw)
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
log.debug("%s cancelled tasks from group %s", str(self), group_name)
|
||||||
|
|
||||||
|
def cancel_group(self, group_name: str, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels an entire group of tasks.
|
Cancels an entire group of tasks.
|
||||||
|
|
||||||
The task group is subsequently forgotten by the pool.
|
The task group is subsequently forgotten by the pool.
|
||||||
|
|
||||||
If any methods such as :meth:`map` launched meta tasks belonging to that group, these meta tasks are cancelled
|
If any methods such launched meta tasks belonging to that group, these meta tasks are cancelled before the
|
||||||
before the actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will
|
actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will
|
||||||
**never even start**. In the case of :meth:`map` this would mean that its `arg_iter` may be abandoned before it
|
**never even start**.
|
||||||
was fully consumed (if that is even possible).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_name: The name of the group of tasks (and meta tasks) that shall be cancelled.
|
group_name: The name of the group of tasks (and meta tasks) that shall be cancelled.
|
||||||
msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
msg (optional): Passed to the `Task.cancel()` method of every task in the group.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
`InvalidGroupName`: No task group named `group_name` exists in the pool.
|
`InvalidGroupName`: No task group named `group_name` exists in the pool.
|
||||||
"""
|
"""
|
||||||
await super().cancel_group(group_name=group_name, msg=msg)
|
log.debug("%s cancelling tasks in group %s", str(self), group_name)
|
||||||
|
try:
|
||||||
|
group_reg = self._task_groups.pop(group_name)
|
||||||
|
except KeyError:
|
||||||
|
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
|
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
||||||
|
log.debug("%s forgot task group %s", str(self), group_name)
|
||||||
|
|
||||||
async def cancel_all(self, msg: str = None) -> None:
|
def cancel_all(self, msg: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Cancels all tasks still running within the pool (including meta tasks).
|
Cancels all tasks still running within the pool (including meta tasks).
|
||||||
|
|
||||||
If any methods such as :meth:`map` launched meta tasks, these meta tasks are cancelled before the actual tasks
|
If any methods such launched meta tasks belonging to that group, these meta tasks are cancelled before the
|
||||||
are cancelled. This means that any tasks "queued" to be started by a meta task will **never even start**. In the
|
actual tasks are cancelled. This means that any tasks "queued" to be started by a meta task will
|
||||||
case of :meth:`map` this would mean that its `arg_iter` may be abandoned before it was fully consumed (if that
|
**never even start**.
|
||||||
is even possible).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg (optional): Passed to the `Task.cancel()` method of every task specified by the `task_ids`.
|
msg (optional): Passed to the `Task.cancel()` method of every task.
|
||||||
"""
|
"""
|
||||||
await super().cancel_all(msg=msg)
|
log.warning("%s cancelling all tasks!", str(self))
|
||||||
|
kw = self._get_cancel_kw(msg)
|
||||||
|
while self._task_groups:
|
||||||
|
group_name, group_reg = self._task_groups.popitem()
|
||||||
|
self._cancel_and_remove_all_from_group(group_name, group_reg, **kw)
|
||||||
|
|
||||||
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
def _pop_ended_meta_tasks(self) -> Set[Task]:
|
||||||
"""
|
"""
|
||||||
@ -584,7 +504,7 @@ class TaskPool(BaseTaskPool):
|
|||||||
Gathers (i.e. awaits) all ended/cancelled tasks in the pool.
|
Gathers (i.e. awaits) all ended/cancelled tasks in the pool.
|
||||||
|
|
||||||
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
|
The tasks are subsequently forgotten by the pool. This method exists mainly to free up memory of unneeded
|
||||||
`Task` objects. It also gets rid of unneeded meta tasks.
|
`Task` objects. It also gets rid of unneeded (ended/cancelled) meta tasks.
|
||||||
|
|
||||||
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
|
It blocks, **only if** any of the tasks block while catching a `asyncio.CancelledError` or any of the callbacks
|
||||||
registered for the tasks block.
|
registered for the tasks block.
|
||||||
@ -592,24 +512,23 @@ class TaskPool(BaseTaskPool):
|
|||||||
Args:
|
Args:
|
||||||
return_exceptions (optional): Passed directly into `gather`.
|
return_exceptions (optional): Passed directly into `gather`.
|
||||||
"""
|
"""
|
||||||
await super().flush(return_exceptions=return_exceptions)
|
|
||||||
with suppress(CancelledError):
|
with suppress(CancelledError):
|
||||||
await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
|
await gather(*self._meta_tasks_cancelled, *self._pop_ended_meta_tasks(),
|
||||||
return_exceptions=return_exceptions)
|
return_exceptions=return_exceptions)
|
||||||
self._meta_tasks_cancelled.clear()
|
self._meta_tasks_cancelled.clear()
|
||||||
|
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), return_exceptions=return_exceptions)
|
||||||
|
self._tasks_ended.clear()
|
||||||
|
self._tasks_cancelled.clear()
|
||||||
|
|
||||||
async def gather_and_close(self, return_exceptions: bool = False):
|
async def gather_and_close(self, return_exceptions: bool = False):
|
||||||
"""
|
"""
|
||||||
Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
|
Gathers (i.e. awaits) **all** tasks in the pool, then closes it.
|
||||||
|
|
||||||
After this method returns, no more tasks can be started in the pool.
|
Once this method is called, no more tasks can be started in the pool.
|
||||||
|
|
||||||
The `lock()` method must have been called prior to this.
|
|
||||||
|
|
||||||
Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
|
Note that this method may block indefinitely as long as any task in the pool is not done. This includes meta
|
||||||
tasks launched by methods such as :meth:`map`, which ends by itself, only once its `arg_iter` is fully consumed,
|
tasks launched by other methods, which may or may not even end by themselves. To avoid this, make sure to call
|
||||||
which may not even be possible (depending on what the iterable of arguments represents). If you want to avoid
|
:meth:`cancel_all` first.
|
||||||
this, make sure to call :meth:`cancel_all` prior to this.
|
|
||||||
|
|
||||||
This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
|
This method may also block, if one of the tasks blocks while catching a `asyncio.CancelledError` or if any of
|
||||||
the callbacks registered for a task blocks for whatever reason.
|
the callbacks registered for a task blocks for whatever reason.
|
||||||
@ -620,62 +539,112 @@ class TaskPool(BaseTaskPool):
|
|||||||
Raises:
|
Raises:
|
||||||
`PoolStillUnlocked`: The pool has not been locked yet.
|
`PoolStillUnlocked`: The pool has not been locked yet.
|
||||||
"""
|
"""
|
||||||
# TODO: It probably makes sense to put this superclass method call at the end (see TODO in `_map`).
|
self.lock()
|
||||||
await super().gather_and_close(return_exceptions=return_exceptions)
|
not_cancelled_meta_tasks = (task for task_set in self._group_meta_tasks_running.values() for task in task_set)
|
||||||
not_cancelled_meta_tasks = set()
|
|
||||||
while self._group_meta_tasks_running:
|
|
||||||
_, meta_tasks = self._group_meta_tasks_running.popitem()
|
|
||||||
not_cancelled_meta_tasks.update(meta_tasks)
|
|
||||||
with suppress(CancelledError):
|
with suppress(CancelledError):
|
||||||
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
|
await gather(*self._meta_tasks_cancelled, *not_cancelled_meta_tasks, return_exceptions=return_exceptions)
|
||||||
self._meta_tasks_cancelled.clear()
|
self._meta_tasks_cancelled.clear()
|
||||||
|
self._group_meta_tasks_running.clear()
|
||||||
|
await gather(*self._tasks_ended.values(), *self._tasks_cancelled.values(), *self._tasks_running.values(),
|
||||||
|
return_exceptions=return_exceptions)
|
||||||
|
self._tasks_ended.clear()
|
||||||
|
self._tasks_cancelled.clear()
|
||||||
|
self._tasks_running.clear()
|
||||||
|
self._closed.set()
|
||||||
|
|
||||||
@staticmethod
|
async def until_closed(self) -> bool:
|
||||||
def _generate_group_name(prefix: str, coroutine_function: CoroutineFunc) -> str:
|
|
||||||
"""
|
"""
|
||||||
Creates a task group identifier that includes the current datetime.
|
Waits until the pool has been closed. (This method itself does **not** close the pool, but blocks until then.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`True` once the pool is closed.
|
||||||
|
"""
|
||||||
|
return await self._closed.wait()
|
||||||
|
|
||||||
|
|
||||||
|
class TaskPool(BaseTaskPool):
|
||||||
|
"""
|
||||||
|
General purpose task pool class.
|
||||||
|
|
||||||
|
Attempts to emulate part of the interface of `multiprocessing.pool.Pool` from the stdlib.
|
||||||
|
|
||||||
|
A `TaskPool` instance can manage an arbitrary number of concurrent tasks from any coroutine function.
|
||||||
|
Tasks in the pool can all belong to the same coroutine function,
|
||||||
|
but they can also come from any number of different and unrelated coroutine functions.
|
||||||
|
|
||||||
|
As long as there is room in the pool, more tasks can be added. (By default, there is no pool size limit.)
|
||||||
|
Each task started in the pool receives a unique ID, which can be used to cancel specific tasks at any moment.
|
||||||
|
|
||||||
|
Adding tasks blocks **only if** the pool is full at that moment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _generate_group_name(self, prefix: str, coroutine_function: CoroutineFunc) -> str:
|
||||||
|
"""
|
||||||
|
Creates a unique task group identifier.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prefix: The start of the name; will be followed by an underscore.
|
prefix: The start of the name; will be followed by an underscore.
|
||||||
coroutine_function: The function representing the task group.
|
coroutine_function: The function representing the task group.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The constructed 'prefix_function_datetime' string to name a task group.
|
The constructed '{prefix}-{name}-group-{idx}' string to name a task group.
|
||||||
|
(With `name` being the name of the `coroutine_function` and `idx` being an incrementing index.)
|
||||||
"""
|
"""
|
||||||
return f'{prefix}_{coroutine_function.__name__}_{datetime.now().strftime(DATETIME_FORMAT)}'
|
base_name = f'{prefix}-{coroutine_function.__name__}-group'
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
name = f'{base_name}-{i}'
|
||||||
|
if name not in self._task_groups.keys():
|
||||||
|
return name
|
||||||
|
i += 1
|
||||||
|
|
||||||
async def _apply_num(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
|
async def _apply_spawner(self, group_name: str, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
|
||||||
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
num: int = 1, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||||
"""
|
"""
|
||||||
Creates a coroutine with the supplied arguments and runs it as a new task in the pool.
|
Creates coroutines with the supplied arguments and runs them as new tasks in the pool.
|
||||||
|
|
||||||
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
|
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_name:
|
group_name:
|
||||||
Name of the task group to add the new task to.
|
Name of the task group to add the new tasks to.
|
||||||
func:
|
func:
|
||||||
The coroutine function to be run as a task within the task pool.
|
The coroutine function to be run in `num` tasks within the task pool.
|
||||||
args (optional):
|
args (optional):
|
||||||
The positional arguments to pass into the function call.
|
The positional arguments to pass into each function call.
|
||||||
kwargs (optional):
|
kwargs (optional):
|
||||||
The keyword-arguments to pass into the function call.
|
The keyword-arguments to pass into each function call.
|
||||||
num (optional):
|
num (optional):
|
||||||
The number of tasks to spawn with the specified parameters.
|
The number of tasks to spawn with the specified parameters.
|
||||||
end_callback (optional):
|
end_callback (optional):
|
||||||
A callback to execute after the task has ended.
|
A callback to execute after each task has ended.
|
||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
cancel_callback (optional):
|
cancel_callback (optional):
|
||||||
A callback to execute after cancellation of the task.
|
A callback to execute after cancellation of each task.
|
||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
"""
|
"""
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
await gather(*(self._start_task(func(*args, **kwargs), group_name=group_name, end_callback=end_callback,
|
for i in range(num):
|
||||||
cancel_callback=cancel_callback) for _ in range(num)))
|
try:
|
||||||
|
coroutine = func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
# This means there was probably something wrong with the function arguments.
|
||||||
|
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
|
str(e.__class__.__name__), group_name, func.__name__, repr(args), repr(kwargs))
|
||||||
|
continue # TODO: Consider returning instead of continuing
|
||||||
|
try:
|
||||||
|
await self._start_task(coroutine, group_name=group_name, end_callback=end_callback,
|
||||||
|
cancel_callback=cancel_callback)
|
||||||
|
except CancelledError:
|
||||||
|
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
|
||||||
|
# more tasks and can return immediately.
|
||||||
|
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
|
||||||
|
coroutine.close()
|
||||||
|
return
|
||||||
|
|
||||||
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
|
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1, group_name: str = None,
|
||||||
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
Creates tasks with the supplied arguments to be run in the pool.
|
Creates tasks with the supplied arguments to be run in the pool.
|
||||||
|
|
||||||
@ -684,7 +653,12 @@ class TaskPool(BaseTaskPool):
|
|||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
This method blocks, **only if** the pool has not enough room to accommodate `num` new tasks.
|
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||||
|
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||||
|
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`.
|
||||||
|
|
||||||
|
If the entire task group is cancelled before `num` tasks have spawned, since the meta task is cancelled first,
|
||||||
|
the number of tasks spawned will end up being less than `num`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func:
|
func:
|
||||||
@ -694,9 +668,11 @@ class TaskPool(BaseTaskPool):
|
|||||||
kwargs (optional):
|
kwargs (optional):
|
||||||
The keyword-arguments to pass into each function call.
|
The keyword-arguments to pass into each function call.
|
||||||
num (optional):
|
num (optional):
|
||||||
The number of tasks to spawn with the specified parameters.
|
The number of tasks to spawn with the specified parameters. Defaults to 1.
|
||||||
group_name (optional):
|
group_name (optional):
|
||||||
Name of the task group to add the new tasks to.
|
Name of the task group to add the new tasks to. By default, a unique name is constructed in the form
|
||||||
|
:code:`'apply-{name}-group-{idx}'` (with `name` being the name of the `func` and `idx` being an
|
||||||
|
incrementing index).
|
||||||
end_callback (optional):
|
end_callback (optional):
|
||||||
A callback to execute after a task has ended.
|
A callback to execute after a task has ended.
|
||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
@ -705,75 +681,52 @@ class TaskPool(BaseTaskPool):
|
|||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The name of the task group that the newly spawned tasks have been added to.
|
The name of the newly created group (see the `group_name` parameter).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
`PoolIsClosed`: The pool is closed.
|
`PoolIsClosed`: The pool is closed.
|
||||||
`NotCoroutine`: `func` is not a coroutine function.
|
`NotCoroutine`: `func` is not a coroutine function.
|
||||||
`PoolIsLocked`: The pool is currently locked.
|
`PoolIsLocked`: The pool is currently locked.
|
||||||
|
`InvalidGroupName`: A group named `group_name` exists in the pool.
|
||||||
"""
|
"""
|
||||||
self._check_start(function=func)
|
self._check_start(function=func)
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
group_name = self._generate_group_name('apply', func)
|
group_name = self._generate_group_name('apply', func)
|
||||||
group_reg = self._task_groups.setdefault(group_name, TaskGroupRegister())
|
if group_name in self._task_groups.keys():
|
||||||
async with group_reg:
|
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
|
||||||
task = create_task(self._apply_num(group_name, func, args, kwargs, num, end_callback, cancel_callback))
|
self._task_groups.setdefault(group_name, TaskGroupRegister())
|
||||||
await task
|
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||||
|
meta_tasks.add(create_task(self._apply_spawner(group_name, func, args, kwargs, num,
|
||||||
|
end_callback=end_callback, cancel_callback=cancel_callback)))
|
||||||
return group_name
|
return group_name
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _queue_producer(cls, arg_queue: Queue, arg_iter: Iterator[Any], group_name: str) -> None:
|
|
||||||
"""
|
|
||||||
Keeps the arguments queue from :meth:`_map` full as long as the iterator has elements.
|
|
||||||
|
|
||||||
Intended to be run as a meta task of a specific group.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
arg_queue: The queue of function arguments to consume for starting a new task.
|
|
||||||
arg_iter: The iterator of function arguments to put into the queue.
|
|
||||||
group_name: Name of the task group associated with this producer.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
for arg in arg_iter:
|
|
||||||
await arg_queue.put(arg) # This blocks as long as the queue is full.
|
|
||||||
except CancelledError:
|
|
||||||
# This means that no more tasks are supposed to be created from this `_map()` call;
|
|
||||||
# thus, we can immediately drain the entire queue and forget about the rest of the arguments.
|
|
||||||
log.debug("Cancelled consumption of argument iterable in task group '%s'", group_name)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
arg_queue.get_nowait()
|
|
||||||
arg_queue.item_processed()
|
|
||||||
except QueueEmpty:
|
|
||||||
return
|
|
||||||
finally:
|
|
||||||
await arg_queue.put(cls._QUEUE_END_SENTINEL)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB:
|
def _get_map_end_callback(map_semaphore: Semaphore, actual_end_callback: EndCB) -> EndCB:
|
||||||
"""Returns a wrapped `end_callback` for each :meth:`_queue_consumer` task that releases the `map_semaphore`."""
|
"""Returns a wrapped `end_callback` for each :meth:`_arg_consumer` task that releases the `map_semaphore`."""
|
||||||
async def release_callback(task_id: int) -> None:
|
async def release_callback(task_id: int) -> None:
|
||||||
map_semaphore.release()
|
map_semaphore.release()
|
||||||
await execute_optional(actual_end_callback, args=(task_id,))
|
await execute_optional(actual_end_callback, args=(task_id,))
|
||||||
return release_callback
|
return release_callback
|
||||||
|
|
||||||
async def _queue_consumer(self, arg_queue: Queue, group_name: str, func: CoroutineFunc, arg_stars: int = 0,
|
async def _arg_consumer(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
arg_stars: int, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||||
"""
|
"""
|
||||||
Consumes arguments from the queue from :meth:`_map` and keeps a limited number of tasks working on them.
|
Consumes arguments from :meth:`_map` and keeps a limited number of tasks working on them.
|
||||||
|
|
||||||
The queue's maximum size is taken as the limiting value of an internal semaphore, which must be acquired before
|
`num_concurrent` acts as the limiting value of an internal semaphore, which must be acquired before a new task
|
||||||
a new task can be started, and which must be released when one of these tasks ends.
|
can be started, and which must be released when one of these tasks ends.
|
||||||
|
|
||||||
Intended to be run as a meta task of a specific group.
|
Intended to be run as a meta task of a specific group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
arg_queue:
|
|
||||||
The queue of function arguments to consume for starting a new task.
|
|
||||||
group_name:
|
group_name:
|
||||||
Name of the associated task group; passed into :meth:`_start_task`.
|
Name of the associated task group; passed into :meth:`_start_task`.
|
||||||
|
num_concurrent:
|
||||||
|
The maximum number new tasks spawned by this method to run concurrently.
|
||||||
func:
|
func:
|
||||||
The coroutine function to use for spawning the new tasks within the task pool.
|
The coroutine function to use for spawning the new tasks within the task pool.
|
||||||
|
arg_iter:
|
||||||
|
The iterable of arguments; each element is to be passed into a `func` call when spawning a new task.
|
||||||
arg_stars (optional):
|
arg_stars (optional):
|
||||||
Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2.
|
Whether or not to unpack an element from `arg_queue` using stars; must be 0, 1, or 2.
|
||||||
end_callback (optional):
|
end_callback (optional):
|
||||||
@ -783,53 +736,60 @@ class TaskPool(BaseTaskPool):
|
|||||||
The callback that was specified to execute after cancellation of the task (and the next one).
|
The callback that was specified to execute after cancellation of the task (and the next one).
|
||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
"""
|
"""
|
||||||
map_semaphore = Semaphore(arg_queue.maxsize) # value determined by `group_size` in :meth:`_map`
|
semaphore = Semaphore(num_concurrent)
|
||||||
release_cb = self._get_map_end_callback(map_semaphore, actual_end_callback=end_callback)
|
release_cb = self._get_map_end_callback(semaphore, actual_end_callback=end_callback)
|
||||||
while True:
|
for i, next_arg in enumerate(arg_iter):
|
||||||
# The following line blocks **only if** the number of running tasks spawned by this method has reached the
|
semaphore_acquired = False
|
||||||
# specified maximum as determined in :meth:`_map`.
|
try:
|
||||||
await map_semaphore.acquire()
|
coroutine = star_function(func, next_arg, arg_stars=arg_stars)
|
||||||
# We await the queue's `get()` coroutine and subsequently ensure that its `task_done()` method is called.
|
except Exception as e:
|
||||||
async with arg_queue as next_arg:
|
# This means there was probably something wrong with the function arguments.
|
||||||
if next_arg is self._QUEUE_END_SENTINEL:
|
log.exception("%s occurred in group '%s' while trying to create coroutine: %s(%s%s)",
|
||||||
# The :meth:`_queue_producer` either reached the last argument or was cancelled.
|
str(e.__class__.__name__), group_name, func.__name__, '*' * arg_stars, str(next_arg))
|
||||||
return
|
continue
|
||||||
try:
|
try:
|
||||||
await self._start_task(star_function(func, next_arg, arg_stars=arg_stars), group_name=group_name,
|
# When the number of running tasks spawned by this method reaches the specified maximum,
|
||||||
ignore_lock=True, end_callback=release_cb, cancel_callback=cancel_callback)
|
# this next line will block, until one of them ends and releases the semaphore.
|
||||||
except Exception as e:
|
semaphore_acquired = await semaphore.acquire()
|
||||||
# This means an exception occurred during task **creation**, meaning no task has been created.
|
await self._start_task(coroutine, group_name=group_name, ignore_lock=True,
|
||||||
# It does not imply an error within the task itself.
|
end_callback=release_cb, cancel_callback=cancel_callback)
|
||||||
log.exception("%s occurred while trying to create task: %s(%s%s)",
|
except CancelledError:
|
||||||
str(e.__class__.__name__), func.__name__, '*' * arg_stars, str(next_arg))
|
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
|
||||||
map_semaphore.release()
|
# more tasks and can return immediately. (This means we drop `arg_iter` without consuming it fully.)
|
||||||
|
log.debug("Cancelled group '%s' after %s tasks have been spawned", group_name, i)
|
||||||
|
coroutine.close()
|
||||||
|
if semaphore_acquired:
|
||||||
|
semaphore.release()
|
||||||
|
return
|
||||||
|
|
||||||
async def _map(self, group_name: str, group_size: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
def _map(self, group_name: str, num_concurrent: int, func: CoroutineFunc, arg_iter: ArgsT, arg_stars: int,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> None:
|
||||||
"""
|
"""
|
||||||
Creates tasks in the pool with arguments from the supplied iterable.
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
|
|
||||||
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
Each coroutine looks like `func(arg)`, `func(*arg)`, or `func(**arg)`, `arg` being taken from `arg_iter`.
|
||||||
|
The method is a task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
`num_concurrent` determines the (maximum) number of tasks spawned this way that shall be running concurrently at
|
||||||
any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
|
any given moment in time. As soon as one task from this method call ends, it triggers the start of a new task
|
||||||
(assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
|
(assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
|
||||||
of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
|
of the pool never imposes a limit, this ensures that the number of tasks spawned and running concurrently is
|
||||||
concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
|
always equal to `num_concurrent` (except for when `arg_iter` is exhausted of course).
|
||||||
|
|
||||||
This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
|
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||||
Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
|
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||||
aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does
|
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num_concurrent`.
|
||||||
not mean that any task was started or that any number of tasks will start soon, as this is solely determined by
|
|
||||||
the :attr:`BaseTaskPool.pool_size` and the `group_size`.
|
If the entire task group is cancelled, the meta task is cancelled first, which means that `arg_iter` may be
|
||||||
|
abandoned before being fully consumed (if that is even possible).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
group_name:
|
group_name:
|
||||||
Name of the task group to add the new tasks to. It must be a name that doesn't exist yet.
|
Name of the task group to add the new tasks to. It must be a name that doesn't exist yet.
|
||||||
group_size:
|
num_concurrent:
|
||||||
The maximum number new tasks spawned by this method to run concurrently.
|
The number new tasks spawned by this method to run concurrently.
|
||||||
func:
|
func:
|
||||||
The coroutine function to use for spawning the new tasks within the task pool.
|
The coroutine function to use for spawning the new tasks within the task pool.
|
||||||
arg_iter:
|
arg_iter:
|
||||||
@ -844,62 +804,53 @@ class TaskPool(BaseTaskPool):
|
|||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
`ValueError`: `group_size` is less than 1.
|
`ValueError`: `num_concurrent` is less than 1.
|
||||||
`asyncio_taskpool.exceptions.InvalidGroupName`: A group named `group_name` exists in the pool.
|
`asyncio_taskpool.exceptions.InvalidGroupName`: A group named `group_name` exists in the pool.
|
||||||
"""
|
"""
|
||||||
self._check_start(function=func)
|
self._check_start(function=func)
|
||||||
if group_size < 1:
|
if num_concurrent < 1:
|
||||||
raise ValueError(f"Group size must be a positive integer.")
|
raise ValueError("`num_concurrent` must be a positive integer.")
|
||||||
if group_name in self._task_groups.keys():
|
if group_name in self._task_groups.keys():
|
||||||
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
|
raise exceptions.InvalidGroupName(f"Group named {group_name} already exists!")
|
||||||
self._task_groups[group_name] = group_reg = TaskGroupRegister()
|
self._task_groups[group_name] = TaskGroupRegister()
|
||||||
async with group_reg:
|
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||||
# Set up internal arguments queue. We limit its maximum size to enable lazy consumption of `arg_iter` by the
|
meta_tasks.add(create_task(self._arg_consumer(group_name, num_concurrent, func, arg_iter, arg_stars,
|
||||||
# `_queue_producer()`; that way an argument
|
end_callback=end_callback, cancel_callback=cancel_callback)))
|
||||||
arg_queue = Queue(maxsize=group_size)
|
|
||||||
# TODO: This is the wrong thing to await before gathering!
|
|
||||||
# Since the queue producer and consumer operate in separate tasks, it is possible that the consumer
|
|
||||||
# "finishes" the entire queue before the producer manages to put more items in it, thus returning
|
|
||||||
# the `join` call before the arguments iterator was fully consumed.
|
|
||||||
# Probably the queue producer task should be awaited before gathering instead.
|
|
||||||
self._before_gathering.append(join_queue(arg_queue))
|
|
||||||
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
|
||||||
# Start the producer and consumer meta tasks.
|
|
||||||
meta_tasks.add(create_task(self._queue_producer(arg_queue, iter(arg_iter), group_name)))
|
|
||||||
meta_tasks.add(create_task(self._queue_consumer(arg_queue, group_name, func, arg_stars,
|
|
||||||
end_callback, cancel_callback)))
|
|
||||||
|
|
||||||
async def map(self, func: CoroutineFunc, arg_iter: ArgsT, group_size: int = 1, group_name: str = None,
|
def map(self, func: CoroutineFunc, arg_iter: ArgsT, num_concurrent: int = 1, group_name: str = None,
|
||||||
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
A task-based equivalent of the `multiprocessing.pool.Pool.map` method.
|
|
||||||
|
|
||||||
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`.
|
|
||||||
|
Each coroutine looks like `func(arg)`, `arg` being an element taken from `arg_iter`. The method is a task-based
|
||||||
|
equivalent of the `multiprocessing.pool.Pool.map` method.
|
||||||
|
|
||||||
All the new tasks are added to the same task group.
|
All the new tasks are added to the same task group.
|
||||||
|
|
||||||
The `group_size` determines the maximum number of tasks spawned this way that shall be running concurrently at
|
`num_concurrent` determines the (maximum) number of tasks spawned this way that shall be running concurrently at
|
||||||
any given moment in time. As soon as one task from this group ends, it triggers the start of a new task
|
any given moment in time. As soon as one task from this method call ends, it triggers the start of a new task
|
||||||
(assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
|
(assuming there is room in the pool), which consumes the next element from the arguments iterable. If the size
|
||||||
of the pool never imposes a limit, this ensures that the number of tasks belonging to this group and running
|
of the pool never imposes a limit, this ensures that the number of tasks spawned and running concurrently is
|
||||||
concurrently is always equal to `group_size` (except for when `arg_iter` is exhausted of course).
|
always equal to `num_concurrent` (except for when `arg_iter` is exhausted of course).
|
||||||
|
|
||||||
This method sets up an internal arguments queue which is continuously filled while consuming the `arg_iter`.
|
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||||
Because this method delegates the spawning of the tasks to two meta tasks (a producer and a consumer of the
|
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||||
aforementioned queue), it **never blocks**. However, just because this method returns immediately, this does
|
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num_concurrent`.
|
||||||
not mean that any task was started or that any number of tasks will start soon, as this is solely determined by
|
|
||||||
the :attr:`BaseTaskPool.pool_size` and the `group_size`.
|
If the entire task group is cancelled, the meta task is cancelled first, which means that `arg_iter` may be
|
||||||
|
abandoned before being fully consumed (if that is even possible).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func:
|
func:
|
||||||
The coroutine function to use for spawning the new tasks within the task pool.
|
The coroutine function to use for spawning the new tasks within the task pool.
|
||||||
arg_iter:
|
arg_iter:
|
||||||
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
|
The iterable of arguments; each argument is to be passed into a `func` call when spawning a new task.
|
||||||
group_size (optional):
|
num_concurrent (optional):
|
||||||
The maximum number new tasks spawned by this method to run concurrently. Defaults to 1.
|
The number new tasks spawned by this method to run concurrently. Defaults to 1.
|
||||||
group_name (optional):
|
group_name (optional):
|
||||||
Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet.
|
Name of the task group to add the new tasks to. If provided, it must be a name that doesn't exist yet.
|
||||||
|
By default, a unique name is constructed in the form :code:`'map-{name}-group-{idx}'`
|
||||||
|
(with `name` being the name of the `func` and `idx` being an incrementing index).
|
||||||
end_callback (optional):
|
end_callback (optional):
|
||||||
A callback to execute after a task has ended.
|
A callback to execute after a task has ended.
|
||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
@ -908,46 +859,57 @@ class TaskPool(BaseTaskPool):
|
|||||||
It is run with the task's ID as its only positional argument.
|
It is run with the task's ID as its only positional argument.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The name of the task group that the newly spawned tasks will be added to.
|
The name of the newly created group (see the `group_name` parameter).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
`PoolIsClosed`: The pool is closed.
|
`PoolIsClosed`: The pool is closed.
|
||||||
`NotCoroutine`: `func` is not a coroutine function.
|
`NotCoroutine`: `func` is not a coroutine function.
|
||||||
`PoolIsLocked`: The pool is currently locked.
|
`PoolIsLocked`: The pool is currently locked.
|
||||||
`ValueError`: `group_size` is less than 1.
|
`ValueError`: `num_concurrent` is less than 1.
|
||||||
`InvalidGroupName`: A group named `group_name` exists in the pool.
|
`InvalidGroupName`: A group named `group_name` exists in the pool.
|
||||||
"""
|
"""
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
group_name = self._generate_group_name('map', func)
|
group_name = self._generate_group_name('map', func)
|
||||||
await self._map(group_name, group_size, func, arg_iter, 0,
|
self._map(group_name, num_concurrent, func, arg_iter, 0,
|
||||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||||
return group_name
|
return group_name
|
||||||
|
|
||||||
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], group_size: int = 1,
|
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_concurrent: int = 1, group_name: str = None,
|
||||||
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
"""
|
"""
|
||||||
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
|
|
||||||
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
Like :meth:`map` except that the elements of `args_iter` are expected to be iterables themselves to be unpacked
|
||||||
as positional arguments to the function.
|
as positional arguments to the function.
|
||||||
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
Each coroutine then looks like `func(*args)`, `args` being an element from `args_iter`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The name of the newly created group in the form :code:`'starmap-{name}-group-{index}'`
|
||||||
|
(with `name` being the name of the `func` and `idx` being an incrementing index).
|
||||||
"""
|
"""
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
group_name = self._generate_group_name('starmap', func)
|
group_name = self._generate_group_name('starmap', func)
|
||||||
await self._map(group_name, group_size, func, args_iter, 1,
|
self._map(group_name, num_concurrent, func, args_iter, 1,
|
||||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||||
return group_name
|
return group_name
|
||||||
|
|
||||||
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], group_size: int = 1,
|
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_concurrent: int = 1,
|
||||||
group_name: str = None, end_callback: EndCB = None,
|
group_name: str = None, end_callback: EndCB = None, cancel_callback: CancelCB = None) -> str:
|
||||||
cancel_callback: CancelCB = None) -> str:
|
|
||||||
"""
|
"""
|
||||||
|
Creates coroutines with arguments from the supplied iterable and runs them as new tasks in the pool.
|
||||||
|
|
||||||
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
|
Like :meth:`map` except that the elements of `kwargs_iter` are expected to be iterables themselves to be
|
||||||
unpacked as keyword-arguments to the function.
|
unpacked as keyword-arguments to the function.
|
||||||
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
Each coroutine then looks like `func(**kwargs)`, `kwargs` being an element from `kwargs_iter`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The name of the newly created group in the form :code:`'doublestarmap-{name}-group-{index}'`
|
||||||
|
(with `name` being the name of the `func` and `idx` being an incrementing index).
|
||||||
"""
|
"""
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
group_name = self._generate_group_name('doublestarmap', func)
|
group_name = self._generate_group_name('doublestarmap', func)
|
||||||
await self._map(group_name, group_size, func, kwargs_iter, 2,
|
self._map(group_name, num_concurrent, func, kwargs_iter, 2,
|
||||||
end_callback=end_callback, cancel_callback=cancel_callback)
|
end_callback=end_callback, cancel_callback=cancel_callback)
|
||||||
return group_name
|
return group_name
|
||||||
|
|
||||||
|
|
||||||
@ -1002,6 +964,7 @@ class SimpleTaskPool(BaseTaskPool):
|
|||||||
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
|
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
|
||||||
self._end_callback: EndCB = end_callback
|
self._end_callback: EndCB = end_callback
|
||||||
self._cancel_callback: CancelCB = cancel_callback
|
self._cancel_callback: CancelCB = cancel_callback
|
||||||
|
self._start_calls: int = 0
|
||||||
super().__init__(pool_size=pool_size, name=name)
|
super().__init__(pool_size=pool_size, name=name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1009,26 +972,52 @@ class SimpleTaskPool(BaseTaskPool):
|
|||||||
"""Name of the coroutine function used in the pool."""
|
"""Name of the coroutine function used in the pool."""
|
||||||
return self._func.__name__
|
return self._func.__name__
|
||||||
|
|
||||||
async def _start_one(self) -> int:
|
async def _start_num(self, num: int, group_name: str) -> None:
|
||||||
"""Starts a single new task within the pool and returns its ID."""
|
"""Starts `num` new tasks in group `group_name`."""
|
||||||
return await self._start_task(self._func(*self._args, **self._kwargs),
|
for i in range(num):
|
||||||
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
|
try:
|
||||||
|
coroutine = self._func(*self._args, **self._kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
# This means there was probably something wrong with the function arguments.
|
||||||
|
log.exception("%s occurred in '%s' while trying to create coroutine: %s(*%s, **%s)",
|
||||||
|
str(e.__class__.__name__), str(self), self._func.__name__,
|
||||||
|
repr(self._args), repr(self._kwargs))
|
||||||
|
continue # TODO: Consider returning instead of continuing
|
||||||
|
try:
|
||||||
|
await self._start_task(coroutine, group_name=group_name, end_callback=self._end_callback,
|
||||||
|
cancel_callback=self._cancel_callback)
|
||||||
|
except CancelledError:
|
||||||
|
# Either the task group or all tasks were cancelled, so this meta tasks is not supposed to spawn any
|
||||||
|
# more tasks and can return immediately.
|
||||||
|
log.debug("Cancelled group '%s' after %s out of %s tasks have been spawned", group_name, i, num)
|
||||||
|
coroutine.close()
|
||||||
|
return
|
||||||
|
|
||||||
async def start(self, num: int) -> List[int]:
|
def start(self, num: int) -> str:
|
||||||
"""
|
"""
|
||||||
Starts specified number of new tasks in the pool and returns their IDs.
|
Starts specified number of new tasks in the pool as a new group.
|
||||||
|
|
||||||
This method may block if there is less room in the pool than the desired number of new tasks.
|
Because this method delegates the spawning of the tasks to a meta task, it **never blocks**. However, just
|
||||||
|
because this method returns immediately, this does not mean that any task was started or that any number of
|
||||||
|
tasks will start soon, as this is solely determined by the :attr:`BaseTaskPool.pool_size` and `num`.
|
||||||
|
|
||||||
|
If the entire task group is cancelled before `num` tasks have spawned, since the meta task is cancelled first,
|
||||||
|
the number of tasks spawned will end up being less than `num`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num: The number of new tasks to start.
|
num: The number of new tasks to start.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of IDs of the new tasks that have been started (not necessarily in the order they were started).
|
The name of the newly created task group in the form :code:`'start-group-{idx}'`
|
||||||
|
(with `idx` being an incrementing index).
|
||||||
"""
|
"""
|
||||||
ids = await gather(*(self._start_one() for _ in range(num)))
|
self._check_start(function=self._func)
|
||||||
assert isinstance(ids, list) # for PyCharm
|
group_name = f'start-group-{self._start_calls}'
|
||||||
return ids
|
self._start_calls += 1
|
||||||
|
self._task_groups.setdefault(group_name, TaskGroupRegister())
|
||||||
|
meta_tasks = self._group_meta_tasks_running.setdefault(group_name, set())
|
||||||
|
meta_tasks.add(create_task(self._start_num(num, group_name)))
|
||||||
|
return group_name
|
||||||
|
|
||||||
def stop(self, num: int) -> List[int]:
|
def stop(self, num: int) -> List[int]:
|
||||||
"""
|
"""
|
||||||
|
30
tests/__main__.py
Normal file
30
tests/__main__.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
__author__ = "Daniil Fajnberg"
|
||||||
|
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||||
|
__license__ = """GNU LGPLv3.0
|
||||||
|
|
||||||
|
This file is part of asyncio-taskpool.
|
||||||
|
|
||||||
|
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||||
|
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||||
|
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
See the GNU Lesser General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||||
|
If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
|
__doc__ = """
|
||||||
|
Main entry point for all unit tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_suite = unittest.defaultTestLoader.discover('.')
|
||||||
|
test_runner = unittest.TextTestRunner(resultclass=unittest.TextTestResult)
|
||||||
|
result = test_runner.run(test_suite)
|
||||||
|
sys.exit(not result.wasSuccessful())
|
@ -38,7 +38,7 @@ class CLITestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_client = MagicMock(start=mock_client_start)
|
mock_client = MagicMock(start=mock_client_start)
|
||||||
mock_client_cls = MagicMock(return_value=mock_client)
|
mock_client_cls = MagicMock(return_value=mock_client)
|
||||||
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
|
mock_client_kwargs = {'foo': 123, 'bar': 456, 'baz': 789}
|
||||||
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls} | mock_client_kwargs
|
mock_parse_cli.return_value = {module.CLIENT_CLASS: mock_client_cls, **mock_client_kwargs}
|
||||||
self.assertIsNone(await module.main())
|
self.assertIsNone(await module.main())
|
||||||
mock_parse_cli.assert_called_once_with()
|
mock_parse_cli.assert_called_once_with()
|
||||||
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
|
mock_client_cls.assert_called_once_with(**mock_client_kwargs)
|
||||||
|
@ -71,7 +71,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._server_handshake(self.mock_reader, self.mock_writer))
|
||||||
self.assertTrue(self.client._connected)
|
self.assertTrue(self.client._connected)
|
||||||
mock__client_info.assert_called_once_with()
|
mock__client_info.assert_called_once_with()
|
||||||
self.mock_write.assert_called_once_with(json.dumps(mock_info).encode())
|
self.mock_write.assert_has_calls([call(json.dumps(mock_info).encode()), call(b'\n')])
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
self.mock_print.assert_has_calls([
|
self.mock_print.assert_has_calls([
|
||||||
@ -121,7 +121,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
mock__get_command.return_value = cmd = FOO + BAR + ' 123'
|
||||||
self.mock_drain.side_effect = err = ConnectionError()
|
self.mock_drain.side_effect = err = ConnectionError()
|
||||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
self.mock_write.assert_called_once_with(cmd.encode())
|
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_not_awaited()
|
self.mock_read.assert_not_awaited()
|
||||||
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
self.mock_print.assert_called_once_with(err, file=sys.stderr)
|
||||||
@ -133,7 +133,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.mock_print.reset_mock()
|
self.mock_print.reset_mock()
|
||||||
|
|
||||||
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
self.assertIsNone(await self.client._interact(self.mock_reader, self.mock_writer))
|
||||||
self.mock_write.assert_called_once_with(cmd.encode())
|
self.mock_write.assert_has_calls([call(cmd.encode()), call(b'\n')])
|
||||||
self.mock_drain.assert_awaited_once_with()
|
self.mock_drain.assert_awaited_once_with()
|
||||||
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
self.mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
||||||
self.mock_print.assert_called_once_with(FOO)
|
self.mock_print.assert_called_once_with(FOO)
|
||||||
|
@ -35,15 +35,15 @@ from asyncio_taskpool.internals.types import ArgsT, CancelCB, CoroutineFunc, End
|
|||||||
FOO, BAR = 'foo', 'bar'
|
FOO, BAR = 'foo', 'bar'
|
||||||
|
|
||||||
|
|
||||||
class ControlServerTestCase(TestCase):
|
class ControlParserTestCase(TestCase):
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
|
self.help_formatter_factory_patcher = patch.object(parser.ControlParser, 'help_formatter_factory')
|
||||||
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
self.mock_help_formatter_factory = self.help_formatter_factory_patcher.start()
|
||||||
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
self.mock_help_formatter_factory.return_value = RawTextHelpFormatter
|
||||||
self.stream_writer, self.terminal_width = MagicMock(), 420
|
self.stream, self.terminal_width = MagicMock(), 420
|
||||||
self.kwargs = {
|
self.kwargs = {
|
||||||
'stream_writer': self.stream_writer,
|
'stream': self.stream,
|
||||||
'terminal_width': self.terminal_width,
|
'terminal_width': self.terminal_width,
|
||||||
'formatter_class': FOO
|
'formatter_class': FOO
|
||||||
}
|
}
|
||||||
@ -72,10 +72,9 @@ class ControlServerTestCase(TestCase):
|
|||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assertIsInstance(self.parser, ArgumentParser)
|
self.assertIsInstance(self.parser, ArgumentParser)
|
||||||
self.assertEqual(self.stream_writer, self.parser._stream_writer)
|
self.assertEqual(self.stream, self.parser._stream)
|
||||||
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
self.assertEqual(self.terminal_width, self.parser._terminal_width)
|
||||||
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
self.mock_help_formatter_factory.assert_called_once_with(self.terminal_width, FOO)
|
||||||
self.assertFalse(getattr(self.parser, 'exit_on_error'))
|
|
||||||
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
self.assertEqual(RawTextHelpFormatter, getattr(self.parser, 'formatter_class'))
|
||||||
self.assertSetEqual(set(), self.parser._flags)
|
self.assertSetEqual(set(), self.parser._flags)
|
||||||
self.assertIsNone(self.parser._commands)
|
self.assertIsNone(self.parser._commands)
|
||||||
@ -89,7 +88,7 @@ class ControlServerTestCase(TestCase):
|
|||||||
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
||||||
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
||||||
expected_name = 'foo-bar'
|
expected_name = 'foo-bar'
|
||||||
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
|
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
|
||||||
to_omit = ['abc', 'xyz']
|
to_omit = ['abc', 'xyz']
|
||||||
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
|
output = self.parser.add_function_command(foo_bar, omit_params=to_omit, **kwargs)
|
||||||
self.assertEqual(mock_subparser, output)
|
self.assertEqual(mock_subparser, output)
|
||||||
@ -107,7 +106,7 @@ class ControlServerTestCase(TestCase):
|
|||||||
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
mock_get_first_doc_line.return_value = mock_help = 'help 123'
|
||||||
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
kwargs = {FOO: 1, BAR: 2, parser.DESCRIPTION: FOO + BAR}
|
||||||
expected_name = 'get-prop'
|
expected_name = 'get-prop'
|
||||||
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help} | kwargs
|
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: mock_help, **kwargs}
|
||||||
output = self.parser.add_property_command(prop, **kwargs)
|
output = self.parser.add_property_command(prop, **kwargs)
|
||||||
self.assertEqual(mock_subparser, output)
|
self.assertEqual(mock_subparser, output)
|
||||||
mock_get_first_doc_line.assert_called_once_with(get_prop)
|
mock_get_first_doc_line.assert_called_once_with(get_prop)
|
||||||
@ -119,7 +118,7 @@ class ControlServerTestCase(TestCase):
|
|||||||
|
|
||||||
prop = property(get_prop, set_prop)
|
prop = property(get_prop, set_prop)
|
||||||
expected_help = f"Get/set the `.{expected_name}` property"
|
expected_help = f"Get/set the `.{expected_name}` property"
|
||||||
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help} | kwargs
|
expected_kwargs = {parser.NAME: expected_name, parser.PROG: expected_name, parser.HELP: expected_help, **kwargs}
|
||||||
output = self.parser.add_property_command(prop, **kwargs)
|
output = self.parser.add_property_command(prop, **kwargs)
|
||||||
self.assertEqual(mock_subparser, output)
|
self.assertEqual(mock_subparser, output)
|
||||||
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
|
mock_get_first_doc_line.assert_has_calls([call(get_prop), call(set_prop)])
|
||||||
@ -152,8 +151,7 @@ class ControlServerTestCase(TestCase):
|
|||||||
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
|
mock_subparser = MagicMock(set_defaults=mock_set_defaults)
|
||||||
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
|
mock_add_function_command.return_value = mock_add_property_command.return_value = mock_subparser
|
||||||
x = 'x'
|
x = 'x'
|
||||||
common_kwargs = {parser.STREAM_WRITER: self.parser._stream_writer,
|
common_kwargs = {'stream': self.parser._stream, parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
|
||||||
parser.CLIENT_INFO.TERMINAL_WIDTH: self.parser._terminal_width}
|
|
||||||
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
|
expected_output = {'method': mock_subparser, 'prop': mock_subparser}
|
||||||
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
|
output = self.parser.add_class_commands(FooBar, public_only=True, omit_members=['to_omit'], member_arg_name=x)
|
||||||
self.assertDictEqual(expected_output, output)
|
self.assertDictEqual(expected_output, output)
|
||||||
@ -170,12 +168,12 @@ class ControlServerTestCase(TestCase):
|
|||||||
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
|
mock_base_add_subparsers.assert_called_once_with(*args, **kwargs)
|
||||||
|
|
||||||
def test__print_message(self):
|
def test__print_message(self):
|
||||||
self.stream_writer.write = MagicMock()
|
self.stream.write = MagicMock()
|
||||||
self.assertIsNone(self.parser._print_message(''))
|
self.assertIsNone(self.parser._print_message(''))
|
||||||
self.stream_writer.write.assert_not_called()
|
self.stream.write.assert_not_called()
|
||||||
msg = 'foo bar baz'
|
msg = 'foo bar baz'
|
||||||
self.assertIsNone(self.parser._print_message(msg))
|
self.assertIsNone(self.parser._print_message(msg))
|
||||||
self.stream_writer.write.assert_called_once_with(msg.encode())
|
self.stream.write.assert_called_once_with(msg)
|
||||||
|
|
||||||
@patch.object(parser.ControlParser, '_print_message')
|
@patch.object(parser.ControlParser, '_print_message')
|
||||||
def test_exit(self, mock__print_message: MagicMock):
|
def test_exit(self, mock__print_message: MagicMock):
|
||||||
@ -265,12 +263,36 @@ class ControlServerTestCase(TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class RestTestCase(TestCase):
|
class RestTestCase(TestCase):
|
||||||
|
log_lvl: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
cls.log_lvl = parser.log.level
|
||||||
|
parser.log.setLevel(999)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls) -> None:
|
||||||
|
parser.log.setLevel(cls.log_lvl)
|
||||||
|
|
||||||
def test__get_arg_type_wrapper(self):
|
def test__get_arg_type_wrapper(self):
|
||||||
type_wrap = parser._get_arg_type_wrapper(int)
|
type_wrap = parser._get_arg_type_wrapper(int)
|
||||||
self.assertEqual('int', type_wrap.__name__)
|
self.assertEqual('int', type_wrap.__name__)
|
||||||
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
|
self.assertEqual(SUPPRESS, type_wrap(SUPPRESS))
|
||||||
self.assertEqual(13, type_wrap('13'))
|
self.assertEqual(13, type_wrap('13'))
|
||||||
|
|
||||||
|
name = 'abcdef'
|
||||||
|
mock_type = MagicMock(side_effect=[parser.ArgumentTypeError, TypeError, ValueError, Exception], __name__=name)
|
||||||
|
type_wrap = parser._get_arg_type_wrapper(mock_type)
|
||||||
|
self.assertEqual(name, type_wrap.__name__)
|
||||||
|
with self.assertRaises(parser.ArgumentTypeError):
|
||||||
|
type_wrap(FOO)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
type_wrap(FOO)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
type_wrap(FOO)
|
||||||
|
with self.assertRaises(parser.ArgumentTypeError):
|
||||||
|
type_wrap(FOO)
|
||||||
|
|
||||||
@patch.object(parser, '_get_arg_type_wrapper')
|
@patch.object(parser, '_get_arg_type_wrapper')
|
||||||
def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
|
def test__get_type_from_annotation(self, mock__get_arg_type_wrapper: MagicMock):
|
||||||
mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR
|
mock__get_arg_type_wrapper.return_value = expected_output = FOO + BAR
|
||||||
|
@ -21,11 +21,12 @@ Unittests for the `asyncio_taskpool.session` module.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from argparse import ArgumentError, Namespace
|
from argparse import ArgumentError, Namespace
|
||||||
|
from io import StringIO
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||||
|
|
||||||
from asyncio_taskpool.control import session
|
from asyncio_taskpool.control import session
|
||||||
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD, SESSION_MSG_BYTES, STREAM_WRITER
|
from asyncio_taskpool.internals.constants import CLIENT_INFO, CMD
|
||||||
from asyncio_taskpool.exceptions import HelpRequested
|
from asyncio_taskpool.exceptions import HelpRequested
|
||||||
from asyncio_taskpool.pool import SimpleTaskPool
|
from asyncio_taskpool.pool import SimpleTaskPool
|
||||||
|
|
||||||
@ -61,18 +62,19 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(self.mock_reader, self.session._reader)
|
self.assertEqual(self.mock_reader, self.session._reader)
|
||||||
self.assertEqual(self.mock_writer, self.session._writer)
|
self.assertEqual(self.mock_writer, self.session._writer)
|
||||||
self.assertIsNone(self.session._parser)
|
self.assertIsNone(self.session._parser)
|
||||||
|
self.assertIsInstance(self.session._response_buffer, StringIO)
|
||||||
|
|
||||||
@patch.object(session, 'return_or_exception')
|
@patch.object(session, 'return_or_exception')
|
||||||
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
|
async def test__exec_method_and_respond(self, mock_return_or_exception: AsyncMock):
|
||||||
def method(self, arg1, arg2, *var_args, **rest): pass
|
def method(self, arg1, arg2, *var_args, **rest): pass
|
||||||
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
|
test_arg1, test_arg2, test_var_args, test_rest = 123, 'xyz', [0.1, 0.2, 0.3], {'aaa': 1, 'bbb': 11}
|
||||||
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args} | test_rest
|
kwargs = {'arg1': test_arg1, 'arg2': test_arg2, 'var_args': test_var_args}
|
||||||
mock_return_or_exception.return_value = None
|
mock_return_or_exception.return_value = None
|
||||||
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs))
|
self.assertIsNone(await self.session._exec_method_and_respond(method, **kwargs, **test_rest))
|
||||||
mock_return_or_exception.assert_awaited_once_with(
|
mock_return_or_exception.assert_awaited_once_with(
|
||||||
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
|
method, self.mock_pool, test_arg1, test_arg2, *test_var_args, **test_rest
|
||||||
)
|
)
|
||||||
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
|
self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
|
||||||
|
|
||||||
@patch.object(session, 'return_or_exception')
|
@patch.object(session, 'return_or_exception')
|
||||||
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
|
async def test__exec_property_and_respond(self, mock_return_or_exception: AsyncMock):
|
||||||
@ -83,15 +85,16 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_return_or_exception.return_value = None
|
mock_return_or_exception.return_value = None
|
||||||
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
|
self.assertIsNone(await self.session._exec_property_and_respond(prop, **kwargs))
|
||||||
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
|
mock_return_or_exception.assert_awaited_once_with(prop_set, self.mock_pool, **kwargs)
|
||||||
self.mock_writer.write.assert_called_once_with(session.CMD_OK)
|
self.assertEqual(session.CMD_OK.decode(), self.session._response_buffer.getvalue())
|
||||||
|
|
||||||
mock_return_or_exception.reset_mock()
|
mock_return_or_exception.reset_mock()
|
||||||
self.mock_writer.write.reset_mock()
|
self.session._response_buffer.seek(0)
|
||||||
|
self.session._response_buffer.truncate()
|
||||||
|
|
||||||
mock_return_or_exception.return_value = val = 420.69
|
mock_return_or_exception.return_value = val = 420.69
|
||||||
self.assertIsNone(await self.session._exec_property_and_respond(prop))
|
self.assertIsNone(await self.session._exec_property_and_respond(prop))
|
||||||
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
|
mock_return_or_exception.assert_awaited_once_with(prop_get, self.mock_pool)
|
||||||
self.mock_writer.write.assert_called_once_with(str(val).encode())
|
self.assertEqual(str(val), self.session._response_buffer.getvalue())
|
||||||
|
|
||||||
@patch.object(session, 'ControlParser')
|
@patch.object(session, 'ControlParser')
|
||||||
async def test_client_handshake(self, mock_parser_cls: MagicMock):
|
async def test_client_handshake(self, mock_parser_cls: MagicMock):
|
||||||
@ -100,11 +103,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parser_cls.return_value = mock_parser
|
mock_parser_cls.return_value = mock_parser
|
||||||
width = 5678
|
width = 5678
|
||||||
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
msg = ' ' + json.dumps({CLIENT_INFO.TERMINAL_WIDTH: width, FOO: BAR}) + ' '
|
||||||
mock_read = AsyncMock(return_value=msg.encode())
|
mock_readline = AsyncMock(return_value=msg.encode())
|
||||||
self.mock_reader.read = mock_read
|
self.mock_reader.readline = mock_readline
|
||||||
self.mock_writer.drain = AsyncMock()
|
self.mock_writer.drain = AsyncMock()
|
||||||
expected_parser_kwargs = {
|
expected_parser_kwargs = {
|
||||||
STREAM_WRITER: self.mock_writer,
|
'stream': self.session._response_buffer,
|
||||||
CLIENT_INFO.TERMINAL_WIDTH: width,
|
CLIENT_INFO.TERMINAL_WIDTH: width,
|
||||||
'prog': '',
|
'prog': '',
|
||||||
'usage': f'[-h] [{CMD}] ...'
|
'usage': f'[-h] [{CMD}] ...'
|
||||||
@ -115,11 +118,11 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
}
|
}
|
||||||
self.assertIsNone(await self.session.client_handshake())
|
self.assertIsNone(await self.session.client_handshake())
|
||||||
self.assertEqual(mock_parser, self.session._parser)
|
self.assertEqual(mock_parser, self.session._parser)
|
||||||
mock_read.assert_awaited_once_with(SESSION_MSG_BYTES)
|
mock_readline.assert_awaited_once_with()
|
||||||
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
mock_parser_cls.assert_called_once_with(**expected_parser_kwargs)
|
||||||
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
|
mock_add_subparsers.assert_called_once_with(**expected_subparsers_kwargs)
|
||||||
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
|
mock_add_class_commands.assert_called_once_with(self.mock_pool.__class__)
|
||||||
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode())
|
self.mock_writer.write.assert_called_once_with(str(self.mock_pool).encode() + b'\n')
|
||||||
self.mock_writer.drain.assert_awaited_once_with()
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
@patch.object(session.ControlSession, '_exec_property_and_respond')
|
@patch.object(session.ControlSession, '_exec_property_and_respond')
|
||||||
@ -132,10 +135,9 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
kwargs = {FOO: BAR, 'hello': 'python'}
|
kwargs = {FOO: BAR, 'hello': 'python'}
|
||||||
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
|
mock_parse_args = MagicMock(return_value=Namespace(**{CMD: method}, **kwargs))
|
||||||
self.session._parser = MagicMock(parse_args=mock_parse_args)
|
self.session._parser = MagicMock(parse_args=mock_parse_args)
|
||||||
self.mock_writer.write = MagicMock()
|
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_not_called()
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
|
mock__exec_method_and_respond.assert_awaited_once_with(method, **kwargs)
|
||||||
mock__exec_property_and_respond.assert_not_called()
|
mock__exec_property_and_respond.assert_not_called()
|
||||||
|
|
||||||
@ -145,7 +147,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
|
mock_parse_args.return_value = Namespace(**{CMD: prop}, **kwargs)
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_not_called()
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_not_called()
|
mock__exec_method_and_respond.assert_not_called()
|
||||||
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
|
mock__exec_property_and_respond.assert_awaited_once_with(prop, **kwargs)
|
||||||
|
|
||||||
@ -161,47 +163,55 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
|||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
mock__exec_method_and_respond.assert_not_called()
|
mock__exec_method_and_respond.assert_not_called()
|
||||||
mock__exec_property_and_respond.assert_not_called()
|
mock__exec_property_and_respond.assert_not_called()
|
||||||
self.mock_writer.write.assert_called_once_with(str(exc).encode())
|
self.assertEqual(str(exc), self.session._response_buffer.getvalue())
|
||||||
|
|
||||||
mock__exec_property_and_respond.reset_mock()
|
mock__exec_property_and_respond.reset_mock()
|
||||||
mock_parse_args.reset_mock()
|
mock_parse_args.reset_mock()
|
||||||
self.mock_writer.write.reset_mock()
|
self.session._response_buffer.seek(0)
|
||||||
|
self.session._response_buffer.truncate()
|
||||||
|
|
||||||
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
|
mock_parse_args.side_effect = exc = ArgumentError(MagicMock(), "oops")
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_called_once_with(str(exc).encode())
|
self.assertEqual(str(exc), self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_not_awaited()
|
mock__exec_method_and_respond.assert_not_awaited()
|
||||||
mock__exec_property_and_respond.assert_not_awaited()
|
mock__exec_property_and_respond.assert_not_awaited()
|
||||||
|
|
||||||
self.mock_writer.write.reset_mock()
|
|
||||||
mock_parse_args.reset_mock()
|
mock_parse_args.reset_mock()
|
||||||
|
self.session._response_buffer.seek(0)
|
||||||
|
self.session._response_buffer.truncate()
|
||||||
|
|
||||||
mock_parse_args.side_effect = HelpRequested()
|
mock_parse_args.side_effect = HelpRequested()
|
||||||
self.assertIsNone(await self.session._parse_command(msg))
|
self.assertIsNone(await self.session._parse_command(msg))
|
||||||
mock_parse_args.assert_called_once_with(msg.split(' '))
|
mock_parse_args.assert_called_once_with(msg.split(' '))
|
||||||
self.mock_writer.write.assert_not_called()
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
mock__exec_method_and_respond.assert_not_awaited()
|
mock__exec_method_and_respond.assert_not_awaited()
|
||||||
mock__exec_property_and_respond.assert_not_awaited()
|
mock__exec_property_and_respond.assert_not_awaited()
|
||||||
|
|
||||||
@patch.object(session.ControlSession, '_parse_command')
|
@patch.object(session.ControlSession, '_parse_command')
|
||||||
async def test_listen(self, mock__parse_command: AsyncMock):
|
async def test_listen(self, mock__parse_command: AsyncMock):
|
||||||
def make_reader_return_empty():
|
def make_reader_return_empty():
|
||||||
self.mock_reader.read.return_value = b''
|
self.mock_reader.readline.return_value = b''
|
||||||
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
self.mock_writer.drain = AsyncMock(side_effect=make_reader_return_empty)
|
||||||
msg = "fascinating"
|
msg = "fascinating"
|
||||||
self.mock_reader.read = AsyncMock(return_value=f' {msg} '.encode())
|
self.mock_reader.readline = AsyncMock(return_value=f' {msg} '.encode())
|
||||||
|
response = FOO + BAR + FOO
|
||||||
|
self.session._response_buffer.write(response)
|
||||||
self.assertIsNone(await self.session.listen())
|
self.assertIsNone(await self.session.listen())
|
||||||
self.mock_reader.read.assert_has_awaits([call(SESSION_MSG_BYTES), call(SESSION_MSG_BYTES)])
|
self.mock_reader.readline.assert_has_awaits([call(), call()])
|
||||||
mock__parse_command.assert_awaited_once_with(msg)
|
mock__parse_command.assert_awaited_once_with(msg)
|
||||||
|
self.assertEqual('', self.session._response_buffer.getvalue())
|
||||||
|
self.mock_writer.write.assert_called_once_with(response.encode() + b'\n')
|
||||||
self.mock_writer.drain.assert_awaited_once_with()
|
self.mock_writer.drain.assert_awaited_once_with()
|
||||||
|
|
||||||
self.mock_reader.read.reset_mock()
|
self.mock_reader.readline.reset_mock()
|
||||||
mock__parse_command.reset_mock()
|
mock__parse_command.reset_mock()
|
||||||
|
self.mock_writer.write.reset_mock()
|
||||||
self.mock_writer.drain.reset_mock()
|
self.mock_writer.drain.reset_mock()
|
||||||
|
|
||||||
self.mock_server.is_serving = MagicMock(return_value=False)
|
self.mock_server.is_serving = MagicMock(return_value=False)
|
||||||
self.assertIsNone(await self.session.listen())
|
self.assertIsNone(await self.session.listen())
|
||||||
self.mock_reader.read.assert_not_awaited()
|
self.mock_reader.readline.assert_not_awaited()
|
||||||
mock__parse_command.assert_not_awaited()
|
mock__parse_command.assert_not_awaited()
|
||||||
|
self.mock_writer.write.assert_not_called()
|
||||||
self.mock_writer.drain.assert_not_awaited()
|
self.mock_writer.drain.assert_not_awaited()
|
||||||
|
@ -18,10 +18,11 @@ __doc__ = """
|
|||||||
Unittests for the `asyncio_taskpool.helpers` module.
|
Unittests for the `asyncio_taskpool.helpers` module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase, TestCase
|
||||||
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
from unittest.mock import MagicMock, AsyncMock, NonCallableMagicMock, call, patch
|
||||||
|
|
||||||
|
from asyncio_taskpool.internals import constants
|
||||||
from asyncio_taskpool.internals import helpers
|
from asyncio_taskpool.internals import helpers
|
||||||
|
|
||||||
|
|
||||||
@ -81,12 +82,6 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
helpers.star_function(f, a, 123456789)
|
helpers.star_function(f, a, 123456789)
|
||||||
|
|
||||||
async def test_join_queue(self):
|
|
||||||
mock_join = AsyncMock()
|
|
||||||
mock_queue = MagicMock(join=mock_join)
|
|
||||||
self.assertIsNone(await helpers.join_queue(mock_queue))
|
|
||||||
mock_join.assert_awaited_once_with()
|
|
||||||
|
|
||||||
def test_get_first_doc_line(self):
|
def test_get_first_doc_line(self):
|
||||||
expected_output = 'foo bar baz'
|
expected_output = 'foo bar baz'
|
||||||
mock_obj = MagicMock(__doc__=f"""{expected_output}
|
mock_obj = MagicMock(__doc__=f"""{expected_output}
|
||||||
@ -128,3 +123,45 @@ class HelpersTestCase(IsolatedAsyncioTestCase):
|
|||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
helpers.resolve_dotted_path('foo.bar.baz')
|
helpers.resolve_dotted_path('foo.bar.baz')
|
||||||
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
|
mock_import_module.assert_has_calls([call('foo'), call('foo.bar')])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassMethodWorkaroundTestCase(TestCase):
|
||||||
|
def test_init(self):
|
||||||
|
def func(): return 'foo'
|
||||||
|
def getter(): return 'bar'
|
||||||
|
prop = property(getter)
|
||||||
|
instance = helpers.ClassMethodWorkaround(func)
|
||||||
|
self.assertIs(func, instance._getter)
|
||||||
|
instance = helpers.ClassMethodWorkaround(prop)
|
||||||
|
self.assertIs(getter, instance._getter)
|
||||||
|
|
||||||
|
@patch.object(helpers.ClassMethodWorkaround, '__init__', return_value=None)
|
||||||
|
def test_get(self, _mock_init: MagicMock):
|
||||||
|
def func(x: MagicMock): return x.__name__
|
||||||
|
instance = helpers.ClassMethodWorkaround(MagicMock())
|
||||||
|
instance._getter = func
|
||||||
|
obj, cls = None, MagicMock
|
||||||
|
expected_output = 'MagicMock'
|
||||||
|
output = instance.__get__(obj, cls)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
obj = MagicMock(__name__='bar')
|
||||||
|
expected_output = 'bar'
|
||||||
|
output = instance.__get__(obj, cls)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
cls = None
|
||||||
|
output = instance.__get__(obj, cls)
|
||||||
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
|
def test_correct_class(self):
|
||||||
|
is_older_python = constants.PYTHON_BEFORE_39
|
||||||
|
try:
|
||||||
|
constants.PYTHON_BEFORE_39 = True
|
||||||
|
importlib.reload(helpers)
|
||||||
|
self.assertIs(helpers.ClassMethodWorkaround, helpers.classmethod)
|
||||||
|
constants.PYTHON_BEFORE_39 = False
|
||||||
|
importlib.reload(helpers)
|
||||||
|
self.assertIs(classmethod, helpers.classmethod)
|
||||||
|
finally:
|
||||||
|
constants.PYTHON_BEFORE_39 = is_older_python
|
||||||
|
@ -19,15 +19,12 @@ Unittests for the `asyncio_taskpool.pool` module.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Event, Semaphore
|
||||||
from asyncio.queues import QueueEmpty
|
|
||||||
from datetime import datetime
|
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from asyncio_taskpool import pool, exceptions
|
from asyncio_taskpool import pool, exceptions
|
||||||
from asyncio_taskpool.internals.constants import DATETIME_FORMAT
|
|
||||||
|
|
||||||
|
|
||||||
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
|
EMPTY_LIST, EMPTY_DICT, EMPTY_SET = [], {}, set()
|
||||||
@ -86,17 +83,20 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertEqual(0, self.task_pool._num_started)
|
self.assertEqual(0, self.task_pool._num_started)
|
||||||
|
|
||||||
self.assertFalse(self.task_pool._locked)
|
self.assertFalse(self.task_pool._locked)
|
||||||
self.assertFalse(self.task_pool._closed)
|
self.assertIsInstance(self.task_pool._closed, Event)
|
||||||
|
self.assertFalse(self.task_pool._closed.is_set())
|
||||||
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
||||||
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
|
|
||||||
self.assertListEqual(self.task_pool._before_gathering, EMPTY_LIST)
|
|
||||||
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
|
self.assertIsInstance(self.task_pool._enough_room, Semaphore)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||||
|
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
||||||
|
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||||
|
|
||||||
self.assertEqual(self.mock_idx, self.task_pool._idx)
|
self.assertEqual(self.mock_idx, self.task_pool._idx)
|
||||||
|
|
||||||
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
self.mock__add_pool.assert_called_once_with(self.task_pool)
|
||||||
@ -163,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool.get_group_ids(group_name, 'something else')
|
self.task_pool.get_group_ids(group_name, 'something else')
|
||||||
|
|
||||||
async def test__check_start(self):
|
async def test__check_start(self):
|
||||||
self.task_pool._closed = True
|
self.task_pool._closed.set()
|
||||||
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
||||||
try:
|
try:
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
@ -176,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
||||||
with self.assertRaises(exceptions.PoolIsClosed):
|
with self.assertRaises(exceptions.PoolIsClosed):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
||||||
self.task_pool._closed = False
|
self.task_pool._closed.clear()
|
||||||
self.task_pool._locked = True
|
self.task_pool._locked = True
|
||||||
with self.assertRaises(exceptions.PoolIsLocked):
|
with self.assertRaises(exceptions.PoolIsLocked):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
||||||
@ -312,104 +312,32 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._get_running_task(task_id)
|
self.task_pool._get_running_task(task_id)
|
||||||
mock__task_name.assert_not_called()
|
mock__task_name.assert_not_called()
|
||||||
|
|
||||||
|
@patch('warnings.warn')
|
||||||
|
def test__get_cancel_kw(self, mock_warn: MagicMock):
|
||||||
|
msg = None
|
||||||
|
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
msg = 'something'
|
||||||
|
with patch.object(pool, 'PYTHON_BEFORE_39', new=True):
|
||||||
|
self.assertDictEqual(EMPTY_DICT, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_called_once()
|
||||||
|
mock_warn.reset_mock()
|
||||||
|
|
||||||
|
with patch.object(pool, 'PYTHON_BEFORE_39', new=False):
|
||||||
|
self.assertDictEqual({'msg': msg}, pool.BaseTaskPool._get_cancel_kw(msg))
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
@patch.object(pool.BaseTaskPool, '_get_running_task')
|
||||||
def test_cancel(self, mock__get_running_task: MagicMock):
|
def test_cancel(self, mock__get_running_task: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
task_id1, task_id2, task_id3 = 1, 4, 9
|
task_id1, task_id2, task_id3 = 1, 4, 9
|
||||||
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
mock__get_running_task.return_value.cancel = mock_cancel = MagicMock()
|
||||||
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
self.assertIsNone(self.task_pool.cancel(task_id1, task_id2, task_id3, msg=FOO))
|
||||||
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
mock__get_running_task.assert_has_calls([call(task_id1), call(task_id2), call(task_id3)])
|
||||||
mock_cancel.assert_has_calls([call(msg=FOO), call(msg=FOO), call(msg=FOO)])
|
mock__get_cancel_kw.assert_called_once_with(FOO)
|
||||||
|
mock_cancel.assert_has_calls(3 * [call(**fake_cancel_kw)])
|
||||||
def test__cancel_and_remove_all_from_group(self):
|
|
||||||
task_id = 555
|
|
||||||
mock_cancel = MagicMock()
|
|
||||||
self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
|
|
||||||
|
|
||||||
class MockRegister(set, MagicMock):
|
|
||||||
pass
|
|
||||||
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), msg=FOO))
|
|
||||||
mock_cancel.assert_called_once_with(msg=FOO)
|
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
|
||||||
async def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
|
||||||
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
|
|
||||||
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
|
|
||||||
self.task_pool._task_groups[FOO] = mock_group_reg
|
|
||||||
with self.assertRaises(exceptions.InvalidGroupName):
|
|
||||||
await self.task_pool.cancel_group(BAR)
|
|
||||||
mock__cancel_and_remove_all_from_group.assert_not_called()
|
|
||||||
mock_grp_aenter.assert_not_called()
|
|
||||||
mock_grp_aexit.assert_not_called()
|
|
||||||
self.assertIsNone(await self.task_pool.cancel_group(FOO, msg=BAR))
|
|
||||||
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, msg=BAR)
|
|
||||||
mock_grp_aenter.assert_awaited_once_with()
|
|
||||||
mock_grp_aexit.assert_awaited_once()
|
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
|
||||||
async def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock):
|
|
||||||
mock_grp_aenter, mock_grp_aexit = AsyncMock(), AsyncMock()
|
|
||||||
mock_group_reg = MagicMock(__aenter__=mock_grp_aenter, __aexit__=mock_grp_aexit)
|
|
||||||
self.task_pool._task_groups[BAR] = mock_group_reg
|
|
||||||
self.assertIsNone(await self.task_pool.cancel_all(FOO))
|
|
||||||
mock__cancel_and_remove_all_from_group.assert_called_once_with(BAR, mock_group_reg, msg=FOO)
|
|
||||||
mock_grp_aenter.assert_awaited_once_with()
|
|
||||||
mock_grp_aexit.assert_awaited_once()
|
|
||||||
|
|
||||||
async def test_flush(self):
|
|
||||||
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
|
||||||
self.task_pool._tasks_ended = {123: mock_ended_func()}
|
|
||||||
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
|
|
||||||
self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
|
|
||||||
mock_ended_func.assert_awaited_once_with()
|
|
||||||
mock_cancelled_func.assert_awaited_once_with()
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
|
||||||
|
|
||||||
async def test_gather_and_close(self):
|
|
||||||
mock_before_gather, mock_running_func = AsyncMock(), AsyncMock()
|
|
||||||
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
|
||||||
self.task_pool._before_gathering = before_gather = [mock_before_gather()]
|
|
||||||
self.task_pool._tasks_ended = ended = {123: mock_ended_func()}
|
|
||||||
self.task_pool._tasks_cancelled = cancelled = {456: mock_cancelled_func()}
|
|
||||||
self.task_pool._tasks_running = running = {789: mock_running_func()}
|
|
||||||
|
|
||||||
with self.assertRaises(exceptions.PoolStillUnlocked):
|
|
||||||
await self.task_pool.gather_and_close()
|
|
||||||
self.assertDictEqual(ended, self.task_pool._tasks_ended)
|
|
||||||
self.assertDictEqual(cancelled, self.task_pool._tasks_cancelled)
|
|
||||||
self.assertDictEqual(running, self.task_pool._tasks_running)
|
|
||||||
self.assertListEqual(before_gather, self.task_pool._before_gathering)
|
|
||||||
self.assertFalse(self.task_pool._closed)
|
|
||||||
|
|
||||||
self.task_pool._locked = True
|
|
||||||
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
|
||||||
mock_before_gather.assert_awaited_once_with()
|
|
||||||
mock_ended_func.assert_awaited_once_with()
|
|
||||||
mock_cancelled_func.assert_awaited_once_with()
|
|
||||||
mock_running_func.assert_awaited_once_with()
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
|
||||||
self.assertListEqual(EMPTY_LIST, self.task_pool._before_gathering)
|
|
||||||
self.assertTrue(self.task_pool._closed)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskPoolTestCase(CommonTestCase):
|
|
||||||
TEST_CLASS = pool.TaskPool
|
|
||||||
task_pool: pool.TaskPool
|
|
||||||
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
|
|
||||||
self.base_class_init = self.base_class_init_patcher.start()
|
|
||||||
super().setUp()
|
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
|
||||||
self.base_class_init_patcher.stop()
|
|
||||||
super().tearDown()
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
|
||||||
self.base_class_init.assert_called_once_with(pool_size=self.TEST_POOL_SIZE, name=self.TEST_POOL_NAME)
|
|
||||||
|
|
||||||
def test__cancel_group_meta_tasks(self):
|
def test__cancel_group_meta_tasks(self):
|
||||||
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
mock_task1, mock_task2 = MagicMock(), MagicMock()
|
||||||
@ -426,26 +354,48 @@ class TaskPoolTestCase(CommonTestCase):
|
|||||||
mock_task1.cancel.assert_called_once_with()
|
mock_task1.cancel.assert_called_once_with()
|
||||||
mock_task2.cancel.assert_called_once_with()
|
mock_task2.cancel.assert_called_once_with()
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_cancel_group_meta_tasks')
|
||||||
|
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock):
|
||||||
|
kw = {BAR: 10, BAZ: 20}
|
||||||
|
task_id = 555
|
||||||
|
mock_cancel = MagicMock()
|
||||||
|
|
||||||
|
def add_mock_task_to_running(_):
|
||||||
|
self.task_pool._tasks_running[task_id] = MagicMock(cancel=mock_cancel)
|
||||||
|
# We add the fake task to the `_tasks_running` dictionary as a side effect of calling the mocked method,
|
||||||
|
# to verify that it is called first, before the cancellation loop starts.
|
||||||
|
mock__cancel_group_meta_tasks.side_effect = add_mock_task_to_running
|
||||||
|
|
||||||
|
class MockRegister(set, MagicMock):
|
||||||
|
pass
|
||||||
|
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(' ', MockRegister({task_id, 'x'}), **kw))
|
||||||
|
mock_cancel.assert_called_once_with(**kw)
|
||||||
|
|
||||||
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
@patch.object(pool.TaskPool, '_cancel_group_meta_tasks')
|
def test_cancel_group(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
def test__cancel_and_remove_all_from_group(self, mock__cancel_group_meta_tasks: MagicMock,
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
mock_base__cancel_and_remove_all_from_group: MagicMock):
|
self.task_pool._task_groups[FOO] = mock_group_reg = MagicMock()
|
||||||
group_name, group_reg, msg = 'xyz', MagicMock(), FOO
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
self.assertIsNone(self.task_pool._cancel_and_remove_all_from_group(group_name, group_reg, msg=msg))
|
self.task_pool.cancel_group(BAR)
|
||||||
mock__cancel_group_meta_tasks.assert_called_once_with(group_name)
|
mock__cancel_and_remove_all_from_group.assert_not_called()
|
||||||
mock_base__cancel_and_remove_all_from_group.assert_called_once_with(group_name, group_reg, msg=msg)
|
self.assertIsNone(self.task_pool.cancel_group(FOO, msg=BAR))
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._task_groups)
|
||||||
|
mock__get_cancel_kw.assert_called_once_with(BAR)
|
||||||
|
mock__cancel_and_remove_all_from_group.assert_called_once_with(FOO, mock_group_reg, **fake_cancel_kw)
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, 'cancel_group')
|
@patch.object(pool.BaseTaskPool, '_get_cancel_kw')
|
||||||
async def test_cancel_group(self, mock_base_cancel_group: AsyncMock):
|
@patch.object(pool.BaseTaskPool, '_cancel_and_remove_all_from_group')
|
||||||
group_name, msg = 'abc', 'xyz'
|
def test_cancel_all(self, mock__cancel_and_remove_all_from_group: MagicMock, mock__get_cancel_kw: MagicMock):
|
||||||
await self.task_pool.cancel_group(group_name, msg=msg)
|
mock__get_cancel_kw.return_value = fake_cancel_kw = {'a': 10, 'b': 20}
|
||||||
mock_base_cancel_group.assert_awaited_once_with(group_name=group_name, msg=msg)
|
mock_group_reg = MagicMock()
|
||||||
|
self.task_pool._task_groups = {FOO: mock_group_reg, BAR: mock_group_reg}
|
||||||
@patch.object(pool.BaseTaskPool, 'cancel_all')
|
self.assertIsNone(self.task_pool.cancel_all(BAZ))
|
||||||
async def test_cancel_all(self, mock_base_cancel_all: AsyncMock):
|
mock__get_cancel_kw.assert_called_once_with(BAZ)
|
||||||
msg = 'xyz'
|
mock__cancel_and_remove_all_from_group.assert_has_calls([
|
||||||
await self.task_pool.cancel_all(msg=msg)
|
call(BAR, mock_group_reg, **fake_cancel_kw),
|
||||||
mock_base_cancel_all.assert_awaited_once_with(msg=msg)
|
call(FOO, mock_group_reg, **fake_cancel_kw)
|
||||||
|
])
|
||||||
|
|
||||||
def test__pop_ended_meta_tasks(self):
|
def test__pop_ended_meta_tasks(self):
|
||||||
mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
|
mock_task, mock_done_task1 = MagicMock(done=lambda: False), MagicMock(done=lambda: True)
|
||||||
@ -457,131 +407,162 @@ class TaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertSetEqual(expected_output, output)
|
self.assertSetEqual(expected_output, output)
|
||||||
self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
|
self.assertDictEqual({FOO: {mock_task}}, self.task_pool._group_meta_tasks_running)
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_pop_ended_meta_tasks')
|
@patch.object(pool.BaseTaskPool, '_pop_ended_meta_tasks')
|
||||||
@patch.object(pool.BaseTaskPool, 'flush')
|
async def test_flush(self, mock__pop_ended_meta_tasks: MagicMock):
|
||||||
async def test_flush(self, mock_base_flush: AsyncMock, mock__pop_ended_meta_tasks: MagicMock):
|
# Meta tasks:
|
||||||
mock_ended_meta_task = AsyncMock()
|
mock_ended_meta_task = AsyncMock()
|
||||||
mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
|
mock__pop_ended_meta_tasks.return_value = {mock_ended_meta_task()}
|
||||||
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
||||||
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
||||||
self.assertIsNone(await self.task_pool.flush(return_exceptions=False))
|
# Actual tasks:
|
||||||
mock_base_flush.assert_awaited_once_with(return_exceptions=False)
|
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
||||||
|
self.task_pool._tasks_ended = {123: mock_ended_func()}
|
||||||
|
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
|
||||||
|
|
||||||
|
self.assertIsNone(await self.task_pool.flush(return_exceptions=True))
|
||||||
|
|
||||||
|
# Meta tasks:
|
||||||
mock__pop_ended_meta_tasks.assert_called_once_with()
|
mock__pop_ended_meta_tasks.assert_called_once_with()
|
||||||
mock_ended_meta_task.assert_awaited_once_with()
|
mock_ended_meta_task.assert_awaited_once_with()
|
||||||
mock_cancelled_meta_task.assert_awaited_once_with()
|
mock_cancelled_meta_task.assert_awaited_once_with()
|
||||||
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||||
|
# Actual tasks:
|
||||||
|
mock_ended_func.assert_awaited_once_with()
|
||||||
|
mock_cancelled_func.assert_awaited_once_with()
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
|
|
||||||
@patch.object(pool.BaseTaskPool, 'gather_and_close')
|
@patch.object(pool.BaseTaskPool, 'lock')
|
||||||
async def test_gather_and_close(self, mock_base_gather_and_close: AsyncMock):
|
async def test_gather_and_close(self, mock_lock: MagicMock):
|
||||||
|
# Meta tasks:
|
||||||
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
|
mock_meta_task1, mock_meta_task2 = AsyncMock(), AsyncMock()
|
||||||
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
|
self.task_pool._group_meta_tasks_running = {FOO: {mock_meta_task1()}, BAR: {mock_meta_task2()}}
|
||||||
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
mock_cancelled_meta_task = AsyncMock(side_effect=CancelledError)
|
||||||
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
self.task_pool._meta_tasks_cancelled = {mock_cancelled_meta_task()}
|
||||||
|
# Actual tasks:
|
||||||
|
mock_running_func = AsyncMock()
|
||||||
|
mock_ended_func, mock_cancelled_func = AsyncMock(), AsyncMock(side_effect=Exception)
|
||||||
|
self.task_pool._tasks_ended = {123: mock_ended_func()}
|
||||||
|
self.task_pool._tasks_cancelled = {456: mock_cancelled_func()}
|
||||||
|
self.task_pool._tasks_running = {789: mock_running_func()}
|
||||||
|
|
||||||
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
self.assertIsNone(await self.task_pool.gather_and_close(return_exceptions=True))
|
||||||
mock_base_gather_and_close.assert_awaited_once_with(return_exceptions=True)
|
|
||||||
|
mock_lock.assert_called_once_with()
|
||||||
|
# Meta tasks:
|
||||||
mock_meta_task1.assert_awaited_once_with()
|
mock_meta_task1.assert_awaited_once_with()
|
||||||
mock_meta_task2.assert_awaited_once_with()
|
mock_meta_task2.assert_awaited_once_with()
|
||||||
mock_cancelled_meta_task.assert_awaited_once_with()
|
mock_cancelled_meta_task.assert_awaited_once_with()
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._group_meta_tasks_running)
|
||||||
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
self.assertSetEqual(EMPTY_SET, self.task_pool._meta_tasks_cancelled)
|
||||||
|
# Actual tasks:
|
||||||
|
mock_ended_func.assert_awaited_once_with()
|
||||||
|
mock_cancelled_func.assert_awaited_once_with()
|
||||||
|
mock_running_func.assert_awaited_once_with()
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
|
self.assertTrue(self.task_pool._closed.is_set())
|
||||||
|
|
||||||
@patch.object(pool, 'datetime')
|
async def test_until_closed(self):
|
||||||
def test__generate_group_name(self, mock_datetime: MagicMock):
|
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
|
||||||
|
output = await self.task_pool.until_closed()
|
||||||
|
self.assertEqual(FOO, output)
|
||||||
|
self.task_pool._closed.wait.assert_awaited_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
class TaskPoolTestCase(CommonTestCase):
|
||||||
|
TEST_CLASS = pool.TaskPool
|
||||||
|
task_pool: pool.TaskPool
|
||||||
|
|
||||||
|
def test__generate_group_name(self):
|
||||||
prefix, func = 'x y z', AsyncMock(__name__=BAR)
|
prefix, func = 'x y z', AsyncMock(__name__=BAR)
|
||||||
dt = datetime(1776, 7, 4, 0, 0, 1)
|
base_name = f'{prefix}-{BAR}-group'
|
||||||
mock_datetime.now = MagicMock(return_value=dt)
|
self.task_pool._task_groups = {
|
||||||
expected_output = f'{prefix}_{BAR}_{dt.strftime(DATETIME_FORMAT)}'
|
f'{base_name}-0': MagicMock(),
|
||||||
output = pool.TaskPool._generate_group_name(prefix, func)
|
f'{base_name}-1': MagicMock(),
|
||||||
|
f'{base_name}-100': MagicMock(),
|
||||||
|
}
|
||||||
|
expected_output = f'{base_name}-2'
|
||||||
|
output = self.task_pool._generate_group_name(prefix, func)
|
||||||
self.assertEqual(expected_output, output)
|
self.assertEqual(expected_output, output)
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_start_task')
|
@patch.object(pool.TaskPool, '_start_task')
|
||||||
async def test__apply_num(self, mock__start_task: AsyncMock):
|
async def test__apply_spawner(self, mock__start_task: AsyncMock):
|
||||||
group_name = FOO + BAR
|
grp_name = FOO + BAR
|
||||||
mock_awaitable = object()
|
mock_awaitable1, mock_awaitable2 = object(), object()
|
||||||
mock_func = MagicMock(return_value=mock_awaitable)
|
mock_func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
|
||||||
args, kwargs, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
|
args, kw, num = (FOO, BAR), {'a': 1, 'b': 2}, 3
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, kwargs, num, end_cb, cancel_cb))
|
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, kw, num, end_cb, cancel_cb))
|
||||||
mock_func.assert_has_calls(3 * [call(*args, **kwargs)])
|
mock_func.assert_has_calls(num * [call(*args, **kw)])
|
||||||
mock__start_task.assert_has_awaits(3 * [
|
mock__start_task.assert_has_awaits([
|
||||||
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
|
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
|
||||||
|
call(mock_awaitable2, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
|
||||||
])
|
])
|
||||||
|
|
||||||
mock_func.reset_mock()
|
mock_func.reset_mock(side_effect=True)
|
||||||
mock__start_task.reset_mock()
|
mock__start_task.reset_mock()
|
||||||
|
|
||||||
self.assertIsNone(await self.task_pool._apply_num(group_name, mock_func, args, None, num, end_cb, cancel_cb))
|
# Simulate cancellation while the second task is being started.
|
||||||
mock_func.assert_has_calls(num * [call(*args)])
|
mock__start_task.side_effect = [None, CancelledError, None]
|
||||||
mock__start_task.assert_has_awaits(num * [
|
mock_coroutine_to_close = MagicMock()
|
||||||
call(mock_awaitable, group_name=group_name, end_callback=end_cb, cancel_callback=cancel_cb)
|
mock_func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
|
||||||
|
self.assertIsNone(await self.task_pool._apply_spawner(grp_name, mock_func, args, None, num, end_cb, cancel_cb))
|
||||||
|
mock_func.assert_has_calls(2 * [call(*args)])
|
||||||
|
mock__start_task.assert_has_awaits([
|
||||||
|
call(mock_awaitable1, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
|
||||||
|
call(mock_coroutine_to_close, group_name=grp_name, end_callback=end_cb, cancel_callback=cancel_cb),
|
||||||
])
|
])
|
||||||
|
mock_coroutine_to_close.close.assert_called_once_with()
|
||||||
|
|
||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool.TaskPool, '_apply_num', new_callable=MagicMock())
|
@patch.object(pool.TaskPool, '_apply_spawner', new_callable=MagicMock())
|
||||||
@patch.object(pool, 'TaskGroupRegister')
|
@patch.object(pool, 'TaskGroupRegister')
|
||||||
@patch.object(pool.TaskPool, '_generate_group_name')
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
@patch.object(pool.BaseTaskPool, '_check_start')
|
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||||
async def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
|
def test_apply(self, mock__check_start: MagicMock, mock__generate_group_name: MagicMock,
|
||||||
mock_reg_cls: MagicMock, mock__apply_num: MagicMock, mock_create_task: MagicMock):
|
mock_reg_cls: MagicMock, mock__apply_spawner: MagicMock, mock_create_task: MagicMock):
|
||||||
mock__generate_group_name.return_value = generated_name = 'name 123'
|
mock__generate_group_name.return_value = generated_name = 'name 123'
|
||||||
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||||
mock__apply_num.return_value = mock_apply_coroutine = object()
|
mock__apply_spawner.return_value = mock_apply_coroutine = object()
|
||||||
mock_task_future = AsyncMock()
|
mock_create_task.return_value = fake_task = object()
|
||||||
mock_create_task.return_value = mock_task_future()
|
|
||||||
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
|
mock_func, num, group_name = MagicMock(), 3, FOO + BAR
|
||||||
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
|
args, kwargs = (FOO, BAR), {'a': 1, 'b': 2}
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
|
|
||||||
|
self.task_pool._task_groups = {group_name: 'causes error'}
|
||||||
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
|
self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
|
||||||
|
mock__check_start.assert_called_once_with(function=mock_func)
|
||||||
|
mock__apply_spawner.assert_not_called()
|
||||||
|
mock_create_task.assert_not_called()
|
||||||
|
|
||||||
|
mock__check_start.reset_mock()
|
||||||
self.task_pool._task_groups = {}
|
self.task_pool._task_groups = {}
|
||||||
|
|
||||||
def check_assertions(_group_name, _output):
|
def check_assertions(_group_name, _output):
|
||||||
self.assertEqual(_group_name, _output)
|
self.assertEqual(_group_name, _output)
|
||||||
mock__check_start.assert_called_once_with(function=mock_func)
|
mock__check_start.assert_called_once_with(function=mock_func)
|
||||||
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
|
self.assertEqual(mock_group_reg, self.task_pool._task_groups[_group_name])
|
||||||
mock_group_reg.__aenter__.assert_awaited_once_with()
|
mock__apply_spawner.assert_called_once_with(_group_name, mock_func, args, kwargs, num,
|
||||||
mock__apply_num.assert_called_once_with(_group_name, mock_func, args, kwargs, num, end_cb, cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock_create_task.assert_called_once_with(mock_apply_coroutine)
|
mock_create_task.assert_called_once_with(mock_apply_coroutine)
|
||||||
mock_group_reg.__aexit__.assert_awaited_once()
|
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
|
||||||
mock_task_future.assert_awaited_once_with()
|
|
||||||
|
|
||||||
output = await self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
|
output = self.task_pool.apply(mock_func, args, kwargs, num, group_name, end_cb, cancel_cb)
|
||||||
check_assertions(group_name, output)
|
check_assertions(group_name, output)
|
||||||
mock__generate_group_name.assert_not_called()
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
mock__check_start.reset_mock()
|
mock__check_start.reset_mock()
|
||||||
self.task_pool._task_groups.clear()
|
self.task_pool._task_groups.clear()
|
||||||
mock_group_reg.__aenter__.reset_mock()
|
mock__apply_spawner.reset_mock()
|
||||||
mock__apply_num.reset_mock()
|
|
||||||
mock_create_task.reset_mock()
|
mock_create_task.reset_mock()
|
||||||
mock_group_reg.__aexit__.reset_mock()
|
|
||||||
mock_task_future = AsyncMock()
|
|
||||||
mock_create_task.return_value = mock_task_future()
|
|
||||||
|
|
||||||
output = await self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
|
output = self.task_pool.apply(mock_func, args, kwargs, num, None, end_cb, cancel_cb)
|
||||||
check_assertions(generated_name, output)
|
check_assertions(generated_name, output)
|
||||||
mock__generate_group_name.assert_called_once_with('apply', mock_func)
|
mock__generate_group_name.assert_called_once_with('apply', mock_func)
|
||||||
|
|
||||||
@patch.object(pool, 'Queue')
|
|
||||||
async def test__queue_producer(self, mock_queue_cls: MagicMock):
|
|
||||||
mock_put = AsyncMock()
|
|
||||||
mock_queue_cls.return_value = mock_queue = MagicMock(put=mock_put)
|
|
||||||
item1, item2, item3 = FOO, 420, 69
|
|
||||||
arg_iter = iter([item1, item2, item3])
|
|
||||||
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
|
|
||||||
mock_put.assert_has_awaits([call(item1), call(item2), call(item3), call(pool.TaskPool._QUEUE_END_SENTINEL)])
|
|
||||||
with self.assertRaises(StopIteration):
|
|
||||||
next(arg_iter)
|
|
||||||
|
|
||||||
mock_put.reset_mock()
|
|
||||||
|
|
||||||
mock_put.side_effect = [CancelledError, None]
|
|
||||||
arg_iter = iter([item1, item2, item3])
|
|
||||||
mock_queue.get_nowait.side_effect = [item2, item3, QueueEmpty]
|
|
||||||
self.assertIsNone(await self.task_pool._queue_producer(mock_queue, arg_iter, FOO + BAR))
|
|
||||||
mock_put.assert_has_awaits([call(item1), call(pool.TaskPool._QUEUE_END_SENTINEL)])
|
|
||||||
mock_queue.get_nowait.assert_has_calls([call(), call(), call()])
|
|
||||||
mock_queue.item_processed.assert_has_calls([call(), call()])
|
|
||||||
self.assertListEqual([item2, item3], list(arg_iter))
|
|
||||||
|
|
||||||
@patch.object(pool, 'execute_optional')
|
@patch.object(pool, 'execute_optional')
|
||||||
async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
|
async def test__get_map_end_callback(self, mock_execute_optional: AsyncMock):
|
||||||
semaphore, mock_end_cb = Semaphore(1), MagicMock()
|
semaphore, mock_end_cb = Semaphore(1), MagicMock()
|
||||||
@ -597,144 +578,176 @@ class TaskPoolTestCase(CommonTestCase):
|
|||||||
@patch.object(pool, 'Semaphore')
|
@patch.object(pool, 'Semaphore')
|
||||||
async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
|
async def test__queue_consumer(self, mock_semaphore_cls: MagicMock, mock__get_map_end_callback: MagicMock,
|
||||||
mock__start_task: AsyncMock, mock_star_function: MagicMock):
|
mock__start_task: AsyncMock, mock_star_function: MagicMock):
|
||||||
mock_semaphore_cls.return_value = semaphore = Semaphore(3)
|
n = 2
|
||||||
|
mock_semaphore_cls.return_value = semaphore = Semaphore(n)
|
||||||
mock__get_map_end_callback.return_value = map_cb = MagicMock()
|
mock__get_map_end_callback.return_value = map_cb = MagicMock()
|
||||||
awaitable = 'totally an awaitable'
|
awaitable1, awaitable2 = 'totally an awaitable', object()
|
||||||
mock_star_function.side_effect = [awaitable, awaitable, Exception()]
|
mock_star_function.side_effect = [awaitable1, Exception(), awaitable2]
|
||||||
arg1, arg2, bad = 123456789, 'function argument', None
|
arg1, arg2, bad = 123456789, 'function argument', None
|
||||||
mock_q_maxsize = 3
|
args = [arg1, bad, arg2]
|
||||||
mock_q = MagicMock(__aenter__=AsyncMock(side_effect=[arg1, arg2, bad, pool.TaskPool._QUEUE_END_SENTINEL]),
|
grp_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
|
||||||
__aexit__=AsyncMock(), maxsize=mock_q_maxsize)
|
|
||||||
group_name, mock_func, stars = 'whatever', MagicMock(__name__="mock"), 3
|
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
self.assertIsNone(await self.task_pool._queue_consumer(mock_q, group_name, mock_func, stars, end_cb, cancel_cb))
|
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, args, stars, end_cb, cancel_cb))
|
||||||
# We expect the semaphore to be acquired 3 times, then be released once after the exception occurs, then
|
# We initialized the semaphore with a value of 2. It should have been acquired twice. We expect it be locked.
|
||||||
# acquired once more when the `_QUEUE_END_SENTINEL` is reached. Since we initialized it with a value of 3,
|
|
||||||
# at the end of the loop, we expect it be locked.
|
|
||||||
self.assertTrue(semaphore.locked())
|
self.assertTrue(semaphore.locked())
|
||||||
mock_semaphore_cls.assert_called_once_with(mock_q_maxsize)
|
mock_semaphore_cls.assert_called_once_with(n)
|
||||||
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
mock__get_map_end_callback.assert_called_once_with(semaphore, actual_end_callback=end_cb)
|
||||||
mock__start_task.assert_has_awaits(2 * [
|
mock__start_task.assert_has_awaits([
|
||||||
call(awaitable, group_name=group_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
|
call(awaitable1, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
|
||||||
|
call(awaitable2, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb),
|
||||||
])
|
])
|
||||||
mock_star_function.assert_has_calls([
|
mock_star_function.assert_has_calls([
|
||||||
call(mock_func, arg1, arg_stars=stars),
|
call(mock_func, arg1, arg_stars=stars),
|
||||||
call(mock_func, arg2, arg_stars=stars),
|
call(mock_func, bad, arg_stars=stars),
|
||||||
call(mock_func, bad, arg_stars=stars)
|
call(mock_func, arg2, arg_stars=stars)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
mock_semaphore_cls.reset_mock()
|
||||||
|
mock__get_map_end_callback.reset_mock()
|
||||||
|
mock__start_task.reset_mock()
|
||||||
|
mock_star_function.reset_mock(side_effect=True)
|
||||||
|
|
||||||
|
# With a CancelledError thrown while acquiring the semaphore:
|
||||||
|
mock_acquire = AsyncMock(side_effect=[True, CancelledError])
|
||||||
|
mock_semaphore_cls.return_value = mock_semaphore = MagicMock(acquire=mock_acquire)
|
||||||
|
mock_star_function.return_value = mock_coroutine = MagicMock()
|
||||||
|
arg_it = iter(arg for arg in (arg1, arg2, FOO))
|
||||||
|
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
|
||||||
|
mock_semaphore_cls.assert_called_once_with(n)
|
||||||
|
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
|
||||||
|
mock_star_function.assert_has_calls([
|
||||||
|
call(mock_func, arg1, arg_stars=stars),
|
||||||
|
call(mock_func, arg2, arg_stars=stars)
|
||||||
|
])
|
||||||
|
mock_acquire.assert_has_awaits([call(), call()])
|
||||||
|
mock__start_task.assert_awaited_once_with(mock_coroutine, group_name=grp_name, ignore_lock=True,
|
||||||
|
end_callback=map_cb, cancel_callback=cancel_cb)
|
||||||
|
mock_coroutine.close.assert_called_once_with()
|
||||||
|
mock_semaphore.release.assert_not_called()
|
||||||
|
self.assertEqual(FOO, next(arg_it))
|
||||||
|
|
||||||
|
mock_acquire.reset_mock(side_effect=True)
|
||||||
|
mock_semaphore_cls.reset_mock()
|
||||||
|
mock__get_map_end_callback.reset_mock()
|
||||||
|
mock__start_task.reset_mock()
|
||||||
|
mock_star_function.reset_mock(side_effect=True)
|
||||||
|
|
||||||
|
# With a CancelledError thrown while starting the task:
|
||||||
|
mock__start_task.side_effect = [None, CancelledError]
|
||||||
|
arg_it = iter(arg for arg in (arg1, arg2, FOO))
|
||||||
|
self.assertIsNone(await self.task_pool._arg_consumer(grp_name, n, mock_func, arg_it, stars, end_cb, cancel_cb))
|
||||||
|
mock_semaphore_cls.assert_called_once_with(n)
|
||||||
|
mock__get_map_end_callback.assert_called_once_with(mock_semaphore, actual_end_callback=end_cb)
|
||||||
|
mock_star_function.assert_has_calls([
|
||||||
|
call(mock_func, arg1, arg_stars=stars),
|
||||||
|
call(mock_func, arg2, arg_stars=stars)
|
||||||
|
])
|
||||||
|
mock_acquire.assert_has_awaits([call(), call()])
|
||||||
|
mock__start_task.assert_has_awaits(2 * [
|
||||||
|
call(mock_coroutine, group_name=grp_name, ignore_lock=True, end_callback=map_cb, cancel_callback=cancel_cb)
|
||||||
|
])
|
||||||
|
mock_coroutine.close.assert_called_once_with()
|
||||||
|
mock_semaphore.release.assert_called_once_with()
|
||||||
|
self.assertEqual(FOO, next(arg_it))
|
||||||
|
|
||||||
@patch.object(pool, 'create_task')
|
@patch.object(pool, 'create_task')
|
||||||
@patch.object(pool.TaskPool, '_queue_consumer', new_callable=MagicMock)
|
@patch.object(pool.TaskPool, '_arg_consumer', new_callable=MagicMock)
|
||||||
@patch.object(pool.TaskPool, '_queue_producer', new_callable=MagicMock)
|
|
||||||
@patch.object(pool, 'join_queue', new_callable=MagicMock)
|
|
||||||
@patch.object(pool, 'Queue')
|
|
||||||
@patch.object(pool, 'TaskGroupRegister')
|
@patch.object(pool, 'TaskGroupRegister')
|
||||||
@patch.object(pool.BaseTaskPool, '_check_start')
|
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||||
async def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock_queue_cls: MagicMock,
|
def test__map(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__arg_consumer: MagicMock,
|
||||||
mock_join_queue: MagicMock, mock__queue_producer: MagicMock, mock__queue_consumer: MagicMock,
|
mock_create_task: MagicMock):
|
||||||
mock_create_task: MagicMock):
|
|
||||||
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||||
mock_queue_cls.return_value = mock_q = MagicMock()
|
mock__arg_consumer.return_value = fake_consumer = object()
|
||||||
mock_join_queue.return_value = fake_join = object()
|
mock_create_task.return_value = fake_task = object()
|
||||||
mock__queue_producer.return_value = fake_producer = object()
|
|
||||||
mock__queue_consumer.return_value = fake_consumer = object()
|
|
||||||
fake_task1, fake_task2 = object(), object()
|
|
||||||
mock_create_task.side_effect = [fake_task1, fake_task2]
|
|
||||||
|
|
||||||
group_name, group_size = 'onetwothree', 0
|
group_name, n = 'onetwothree', 0
|
||||||
func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
|
func, arg_iter, stars = AsyncMock(), [55, 66, 77], 3
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
|
self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
|
||||||
mock__check_start.assert_called_once_with(function=func)
|
mock__check_start.assert_called_once_with(function=func)
|
||||||
|
|
||||||
mock__check_start.reset_mock()
|
mock__check_start.reset_mock()
|
||||||
|
|
||||||
group_size = 1234
|
n = 1234
|
||||||
self.task_pool._task_groups = {group_name: MagicMock()}
|
self.task_pool._task_groups = {group_name: MagicMock()}
|
||||||
|
|
||||||
with self.assertRaises(exceptions.InvalidGroupName):
|
with self.assertRaises(exceptions.InvalidGroupName):
|
||||||
await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb)
|
self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb)
|
||||||
mock__check_start.assert_called_once_with(function=func)
|
mock__check_start.assert_called_once_with(function=func)
|
||||||
|
|
||||||
mock__check_start.reset_mock()
|
mock__check_start.reset_mock()
|
||||||
|
|
||||||
self.task_pool._task_groups.clear()
|
self.task_pool._task_groups.clear()
|
||||||
self.task_pool._before_gathering = []
|
|
||||||
|
|
||||||
self.assertIsNone(await self.task_pool._map(group_name, group_size, func, arg_iter, stars, end_cb, cancel_cb))
|
self.assertIsNone(self.task_pool._map(group_name, n, func, arg_iter, stars, end_cb, cancel_cb))
|
||||||
mock__check_start.assert_called_once_with(function=func)
|
mock__check_start.assert_called_once_with(function=func)
|
||||||
mock_reg_cls.assert_called_once_with()
|
mock_reg_cls.assert_called_once_with()
|
||||||
self.task_pool._task_groups[group_name] = mock_group_reg
|
self.task_pool._task_groups[group_name] = mock_group_reg
|
||||||
mock_group_reg.__aenter__.assert_awaited_once_with()
|
mock__arg_consumer.assert_called_once_with(group_name, n, func, arg_iter, stars,
|
||||||
mock_queue_cls.assert_called_once_with(maxsize=group_size)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock_join_queue.assert_called_once_with(mock_q)
|
mock_create_task.assert_called_once_with(fake_consumer)
|
||||||
self.assertListEqual([fake_join], self.task_pool._before_gathering)
|
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[group_name])
|
||||||
mock__queue_producer.assert_called_once()
|
|
||||||
mock__queue_consumer.assert_called_once_with(mock_q, group_name, func, stars, end_cb, cancel_cb)
|
|
||||||
mock_create_task.assert_has_calls([call(fake_producer), call(fake_consumer)])
|
|
||||||
self.assertSetEqual({fake_task1, fake_task2}, self.task_pool._group_meta_tasks_running[group_name])
|
|
||||||
mock_group_reg.__aexit__.assert_awaited_once()
|
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_map')
|
@patch.object(pool.TaskPool, '_map')
|
||||||
@patch.object(pool.TaskPool, '_generate_group_name')
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
async def test_map(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
def test_map(self, mock__generate_group_name: MagicMock, mock__map: MagicMock):
|
||||||
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||||
mock_func = MagicMock()
|
mock_func = MagicMock()
|
||||||
arg_iter, group_size, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
|
arg_iter, num_concurrent, group_name = (FOO, BAR, 1, 2, 3), 2, FOO + BAR
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
output = await self.task_pool.map(mock_func, arg_iter, group_size, group_name, end_cb, cancel_cb)
|
output = self.task_pool.map(mock_func, arg_iter, num_concurrent, group_name, end_cb, cancel_cb)
|
||||||
self.assertEqual(group_name, output)
|
self.assertEqual(group_name, output)
|
||||||
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, arg_iter, 0,
|
mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, arg_iter, 0,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock__generate_group_name.assert_not_called()
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
mock__map.reset_mock()
|
mock__map.reset_mock()
|
||||||
output = await self.task_pool.map(mock_func, arg_iter, group_size, None, end_cb, cancel_cb)
|
output = self.task_pool.map(mock_func, arg_iter, num_concurrent, None, end_cb, cancel_cb)
|
||||||
self.assertEqual(generated_name, output)
|
self.assertEqual(generated_name, output)
|
||||||
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, arg_iter, 0,
|
mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, arg_iter, 0,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock__generate_group_name.assert_called_once_with('map', mock_func)
|
mock__generate_group_name.assert_called_once_with('map', mock_func)
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_map')
|
@patch.object(pool.TaskPool, '_map')
|
||||||
@patch.object(pool.TaskPool, '_generate_group_name')
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
async def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
def test_starmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock):
|
||||||
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||||
mock_func = MagicMock()
|
mock_func = MagicMock()
|
||||||
args_iter, group_size, group_name = ([FOO], [BAR]), 2, FOO + BAR
|
args_iter, num_concurrent, group_name = ([FOO], [BAR]), 2, FOO + BAR
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
output = await self.task_pool.starmap(mock_func, args_iter, group_size, group_name, end_cb, cancel_cb)
|
output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, group_name, end_cb, cancel_cb)
|
||||||
self.assertEqual(group_name, output)
|
self.assertEqual(group_name, output)
|
||||||
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, args_iter, 1,
|
mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, args_iter, 1,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock__generate_group_name.assert_not_called()
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
mock__map.reset_mock()
|
mock__map.reset_mock()
|
||||||
output = await self.task_pool.starmap(mock_func, args_iter, group_size, None, end_cb, cancel_cb)
|
output = self.task_pool.starmap(mock_func, args_iter, num_concurrent, None, end_cb, cancel_cb)
|
||||||
self.assertEqual(generated_name, output)
|
self.assertEqual(generated_name, output)
|
||||||
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, args_iter, 1,
|
mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, args_iter, 1,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock__generate_group_name.assert_called_once_with('starmap', mock_func)
|
mock__generate_group_name.assert_called_once_with('starmap', mock_func)
|
||||||
|
|
||||||
@patch.object(pool.TaskPool, '_map')
|
@patch.object(pool.TaskPool, '_map')
|
||||||
@patch.object(pool.TaskPool, '_generate_group_name')
|
@patch.object(pool.TaskPool, '_generate_group_name')
|
||||||
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: AsyncMock):
|
async def test_doublestarmap(self, mock__generate_group_name: MagicMock, mock__map: MagicMock):
|
||||||
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
mock__generate_group_name.return_value = generated_name = 'name 1 2 3'
|
||||||
mock_func = MagicMock()
|
mock_func = MagicMock()
|
||||||
kwargs_iter, group_size, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
|
kw_iter, num_concurrent, group_name = [{'a': FOO}, {'a': BAR}], 2, FOO + BAR
|
||||||
end_cb, cancel_cb = MagicMock(), MagicMock()
|
end_cb, cancel_cb = MagicMock(), MagicMock()
|
||||||
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, group_name, end_cb, cancel_cb)
|
output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, group_name, end_cb, cancel_cb)
|
||||||
self.assertEqual(group_name, output)
|
self.assertEqual(group_name, output)
|
||||||
mock__map.assert_awaited_once_with(group_name, group_size, mock_func, kwargs_iter, 2,
|
mock__map.assert_called_once_with(group_name, num_concurrent, mock_func, kw_iter, 2,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock__generate_group_name.assert_not_called()
|
mock__generate_group_name.assert_not_called()
|
||||||
|
|
||||||
mock__map.reset_mock()
|
mock__map.reset_mock()
|
||||||
output = await self.task_pool.doublestarmap(mock_func, kwargs_iter, group_size, None, end_cb, cancel_cb)
|
output = self.task_pool.doublestarmap(mock_func, kw_iter, num_concurrent, None, end_cb, cancel_cb)
|
||||||
self.assertEqual(generated_name, output)
|
self.assertEqual(generated_name, output)
|
||||||
mock__map.assert_awaited_once_with(generated_name, group_size, mock_func, kwargs_iter, 2,
|
mock__map.assert_called_once_with(generated_name, num_concurrent, mock_func, kw_iter, 2,
|
||||||
end_callback=end_cb, cancel_callback=cancel_cb)
|
end_callback=end_cb, cancel_callback=cancel_cb)
|
||||||
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
|
mock__generate_group_name.assert_called_once_with('doublestarmap', mock_func)
|
||||||
|
|
||||||
|
|
||||||
@ -749,13 +762,15 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
TEST_POOL_CANCEL_CB = MagicMock()
|
TEST_POOL_CANCEL_CB = MagicMock()
|
||||||
|
|
||||||
def get_task_pool_init_params(self) -> dict:
|
def get_task_pool_init_params(self) -> dict:
|
||||||
return super().get_task_pool_init_params() | {
|
params = super().get_task_pool_init_params()
|
||||||
|
params.update({
|
||||||
'func': self.TEST_POOL_FUNC,
|
'func': self.TEST_POOL_FUNC,
|
||||||
'args': self.TEST_POOL_ARGS,
|
'args': self.TEST_POOL_ARGS,
|
||||||
'kwargs': self.TEST_POOL_KWARGS,
|
'kwargs': self.TEST_POOL_KWARGS,
|
||||||
'end_callback': self.TEST_POOL_END_CB,
|
'end_callback': self.TEST_POOL_END_CB,
|
||||||
'cancel_callback': self.TEST_POOL_CANCEL_CB,
|
'cancel_callback': self.TEST_POOL_CANCEL_CB,
|
||||||
}
|
})
|
||||||
|
return params
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
|
self.base_class_init_patcher = patch.object(pool.BaseTaskPool, '__init__')
|
||||||
@ -764,6 +779,7 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
self.base_class_init_patcher.stop()
|
self.base_class_init_patcher.stop()
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)
|
self.assertEqual(self.TEST_POOL_FUNC, self.task_pool._func)
|
||||||
@ -780,23 +796,54 @@ class SimpleTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name)
|
self.assertEqual(self.TEST_POOL_FUNC.__name__, self.task_pool.func_name)
|
||||||
|
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_task')
|
@patch.object(pool.SimpleTaskPool, '_start_task')
|
||||||
async def test__start_one(self, mock__start_task: AsyncMock):
|
async def test__start_num(self, mock__start_task: AsyncMock):
|
||||||
mock__start_task.return_value = expected_output = 99
|
group_name = FOO + BAR + 'abc'
|
||||||
self.task_pool._func = MagicMock(return_value=BAR)
|
mock_awaitable1, mock_awaitable2 = object(), object()
|
||||||
output = await self.task_pool._start_one()
|
self.task_pool._func = MagicMock(side_effect=[mock_awaitable1, Exception(), mock_awaitable2], __name__='func')
|
||||||
self.assertEqual(expected_output, output)
|
num = 3
|
||||||
self.task_pool._func.assert_called_once_with(*self.task_pool._args, **self.task_pool._kwargs)
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
mock__start_task.assert_awaited_once_with(BAR, end_callback=self.task_pool._end_callback,
|
self.task_pool._func.assert_has_calls(num * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
||||||
cancel_callback=self.task_pool._cancel_callback)
|
call_kw = {
|
||||||
|
'group_name': group_name,
|
||||||
|
'end_callback': self.task_pool._end_callback,
|
||||||
|
'cancel_callback': self.task_pool._cancel_callback
|
||||||
|
}
|
||||||
|
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_awaitable2, **call_kw)])
|
||||||
|
|
||||||
@patch.object(pool.SimpleTaskPool, '_start_one')
|
self.task_pool._func.reset_mock(side_effect=True)
|
||||||
async def test_start(self, mock__start_one: AsyncMock):
|
mock__start_task.reset_mock()
|
||||||
mock__start_one.return_value = FOO
|
|
||||||
|
# Simulate cancellation while the second task is being started.
|
||||||
|
mock__start_task.side_effect = [None, CancelledError, None]
|
||||||
|
mock_coroutine_to_close = MagicMock()
|
||||||
|
self.task_pool._func.side_effect = [mock_awaitable1, mock_coroutine_to_close, 'never called']
|
||||||
|
self.assertIsNone(await self.task_pool._start_num(num, group_name))
|
||||||
|
self.task_pool._func.assert_has_calls(2 * [call(*self.task_pool._args, **self.task_pool._kwargs)])
|
||||||
|
mock__start_task.assert_has_awaits([call(mock_awaitable1, **call_kw), call(mock_coroutine_to_close, **call_kw)])
|
||||||
|
mock_coroutine_to_close.close.assert_called_once_with()
|
||||||
|
|
||||||
|
@patch.object(pool, 'create_task')
|
||||||
|
@patch.object(pool.SimpleTaskPool, '_start_num', new_callable=MagicMock())
|
||||||
|
@patch.object(pool, 'TaskGroupRegister')
|
||||||
|
@patch.object(pool.BaseTaskPool, '_check_start')
|
||||||
|
def test_start(self, mock__check_start: MagicMock, mock_reg_cls: MagicMock, mock__start_num: AsyncMock,
|
||||||
|
mock_create_task: MagicMock):
|
||||||
|
mock_group_reg = set_up_mock_group_register(mock_reg_cls)
|
||||||
|
mock__start_num.return_value = mock_start_num_coroutine = object()
|
||||||
|
mock_create_task.return_value = fake_task = object()
|
||||||
|
self.task_pool._task_groups = {}
|
||||||
|
self.task_pool._group_meta_tasks_running = {}
|
||||||
num = 5
|
num = 5
|
||||||
output = await self.task_pool.start(num)
|
self.task_pool._start_calls = 42
|
||||||
expected_output = num * [FOO]
|
expected_group_name = 'start-group-42'
|
||||||
self.assertListEqual(expected_output, output)
|
output = self.task_pool.start(num)
|
||||||
mock__start_one.assert_has_awaits(num * [call()])
|
self.assertEqual(expected_group_name, output)
|
||||||
|
mock__check_start.assert_called_once_with(function=self.TEST_POOL_FUNC)
|
||||||
|
self.assertEqual(43, self.task_pool._start_calls)
|
||||||
|
self.assertEqual(mock_group_reg, self.task_pool._task_groups[expected_group_name])
|
||||||
|
mock__start_num.assert_called_once_with(num, expected_group_name)
|
||||||
|
mock_create_task.assert_called_once_with(mock_start_num_coroutine)
|
||||||
|
self.assertSetEqual({fake_task}, self.task_pool._group_meta_tasks_running[expected_group_name])
|
||||||
|
|
||||||
@patch.object(pool.SimpleTaskPool, 'cancel')
|
@patch.object(pool.SimpleTaskPool, 'cancel')
|
||||||
def test_stop(self, mock_cancel: MagicMock):
|
def test_stop(self, mock_cancel: MagicMock):
|
||||||
|
@ -39,12 +39,11 @@ async def work(n: int) -> None:
|
|||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
|
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
|
||||||
await pool.start(3) # launches work tasks 0, 1, and 2
|
pool.start(3) # launches work tasks 0, 1, and 2
|
||||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||||
await pool.start() # launches work task 3
|
pool.start(1) # launches work task 3
|
||||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||||
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
|
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
|
||||||
pool.lock() # required for the last line
|
|
||||||
await pool.gather_and_close() # awaits all tasks, then flushes the pool
|
await pool.gather_and_close() # awaits all tasks, then flushes the pool
|
||||||
|
|
||||||
|
|
||||||
@ -123,7 +122,7 @@ async def main() -> None:
|
|||||||
pool = TaskPool(3)
|
pool = TaskPool(3)
|
||||||
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments).
|
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments).
|
||||||
print("> Called `apply`")
|
print("> Called `apply`")
|
||||||
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
|
pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
|
||||||
# Let the tasks work for a bit.
|
# Let the tasks work for a bit.
|
||||||
await asyncio.sleep(1.5)
|
await asyncio.sleep(1.5)
|
||||||
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
|
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
|
||||||
@ -135,11 +134,9 @@ async def main() -> None:
|
|||||||
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
|
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
|
||||||
# only once there is room in the pool and no more than one other task of these new ones is running.
|
# only once there is room in the pool and no more than one other task of these new ones is running.
|
||||||
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
||||||
await pool.starmap(other_work, args_list, group_size=2)
|
pool.starmap(other_work, args_list, num_concurrent=2)
|
||||||
print("> Called `starmap`")
|
print("> Called `starmap`")
|
||||||
# Now we lock the pool, so that we can safely await all our tasks.
|
# We block, until all tasks have ended.
|
||||||
pool.lock()
|
|
||||||
# Finally, we block, until all tasks have ended.
|
|
||||||
print("> Calling `gather_and_close`...")
|
print("> Calling `gather_and_close`...")
|
||||||
await pool.gather_and_close()
|
await pool.gather_and_close()
|
||||||
print("> Done.")
|
print("> Done.")
|
||||||
@ -199,7 +196,7 @@ Started TaskPool-0_Task-3
|
|||||||
> other_work with 15
|
> other_work with 15
|
||||||
Ended TaskPool-0_Task-0
|
Ended TaskPool-0_Task-0
|
||||||
Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
|
Ended TaskPool-0_Task-1 <--- these two end and free up two more slots in the pool
|
||||||
Started TaskPool-0_Task-4 <--- since the group size is set to 2, Task-5 will not start
|
Started TaskPool-0_Task-4 <--- since `num_concurrent` is set to 2, Task-5 will not start
|
||||||
> work with 190
|
> work with 190
|
||||||
> work with 190
|
> work with 190
|
||||||
> other_work with 16
|
> other_work with 16
|
||||||
|
@ -67,7 +67,7 @@ async def main() -> None:
|
|||||||
for item in range(100):
|
for item in range(100):
|
||||||
q.put_nowait(item)
|
q.put_nowait(item)
|
||||||
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
|
pool = SimpleTaskPool(worker, args=(q,)) # initializes the pool
|
||||||
await pool.start(3) # launches three worker tasks
|
pool.start(3) # launches three worker tasks
|
||||||
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
|
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
|
||||||
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
||||||
await q.join()
|
await q.join()
|
||||||
@ -75,7 +75,6 @@ async def main() -> None:
|
|||||||
control_server_task.cancel()
|
control_server_task.cancel()
|
||||||
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
# Since our workers should now be stuck waiting for more items to pick from the queue, but no items are left,
|
||||||
# we can now safely cancel their tasks.
|
# we can now safely cancel their tasks.
|
||||||
pool.lock()
|
|
||||||
pool.stop_all()
|
pool.stop_all()
|
||||||
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
|
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
|
||||||
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
||||||
|
Loading…
Reference in New Issue
Block a user