generated from daniil-berg/boilerplate-py
Compare commits
8 Commits
v0.4.0-lw
...
360fe578d7
Author | SHA1 | Date | |
---|---|---|---|
360fe578d7 | |||
7c66604ad0 | |||
287906a218 | |||
ce0f9a1f65 | |||
5dad4ab0c7 | |||
ae6bb1bd17 | |||
e501a849f3 | |||
ed6badb088 |
@ -5,7 +5,6 @@ omit =
|
||||
.venv/*
|
||||
|
||||
[report]
|
||||
fail_under = 100
|
||||
show_missing = True
|
||||
skip_covered = False
|
||||
exclude_lines =
|
||||
|
@ -1,6 +1,6 @@
|
||||
[metadata]
|
||||
name = asyncio-taskpool
|
||||
version = 0.4.0
|
||||
version = 0.5.1
|
||||
author = Daniil Fajnberg
|
||||
author_email = mail@daniil.fajnberg.de
|
||||
description = Dynamically manage pools of asyncio tasks
|
||||
|
@ -20,4 +20,4 @@ Brings the main classes up to package level for import convenience.
|
||||
|
||||
|
||||
from .pool import TaskPool, SimpleTaskPool
|
||||
from .server import UnixControlServer
|
||||
from .server import TCPControlServer, UnixControlServer
|
||||
|
@ -25,15 +25,16 @@ from asyncio import run
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from .client import ControlClient, UnixControlClient
|
||||
from .client import ControlClient, TCPControlClient, UnixControlClient
|
||||
from .constants import PACKAGE_NAME
|
||||
from .pool import TaskPool
|
||||
from .server import ControlServer
|
||||
from .server import TCPControlServer, UnixControlServer
|
||||
|
||||
|
||||
CONN_TYPE = 'conn_type'
|
||||
UNIX, TCP = 'unix', 'tcp'
|
||||
SOCKET_PATH = 'path'
|
||||
HOST, PORT = 'host', 'port'
|
||||
|
||||
|
||||
def parse_cli() -> Dict[str, Any]:
|
||||
@ -46,7 +47,18 @@ def parse_cli() -> Dict[str, Any]:
|
||||
unix_parser.add_argument(
|
||||
SOCKET_PATH,
|
||||
type=Path,
|
||||
help=f"Path to the unix socket on which the {ControlServer.__name__} for the {TaskPool.__name__} is listening."
|
||||
help=f"Path to the unix socket on which the {UnixControlServer.__name__} for the {TaskPool.__name__} is "
|
||||
f"listening."
|
||||
)
|
||||
tcp_parser = subparsers.add_parser(TCP, help="Connect via TCP socket")
|
||||
tcp_parser.add_argument(
|
||||
HOST,
|
||||
help=f"IP address or url that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
|
||||
)
|
||||
tcp_parser.add_argument(
|
||||
PORT,
|
||||
type=int,
|
||||
help=f"Port that the {TCPControlServer.__name__} for the {TaskPool.__name__} is listening on."
|
||||
)
|
||||
return vars(parser.parse_args())
|
||||
|
||||
@ -56,8 +68,7 @@ async def main():
|
||||
if kwargs[CONN_TYPE] == UNIX:
|
||||
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
|
||||
elif kwargs[CONN_TYPE] == TCP:
|
||||
# TODO: Implement the TCP client class
|
||||
client = UnixControlClient(socket_path=kwargs[SOCKET_PATH])
|
||||
client = TCPControlClient(host=kwargs[HOST], port=kwargs[PORT])
|
||||
else:
|
||||
print("Invalid connection type", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
@ -23,9 +23,9 @@ import json
|
||||
import shutil
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
|
||||
from asyncio.streams import StreamReader, StreamWriter, open_connection
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from .constants import CLIENT_EXIT, CLIENT_INFO, SESSION_MSG_BYTES
|
||||
from .types import ClientConnT, PathT
|
||||
@ -50,8 +50,8 @@ class ControlClient(ABC):
|
||||
"""
|
||||
Tries to connect to a socket using the provided arguments and return the associated reader-writer-pair.
|
||||
|
||||
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs` (unpacked)
|
||||
as keyword-arguments.
|
||||
This method will be invoked by the public `start()` method with the pre-defined internal `_conn_kwargs`
|
||||
(unpacked) as keyword-arguments.
|
||||
This method should return either a tuple of `asyncio.StreamReader` and `asyncio.StreamWriter` or a tuple of
|
||||
`None` and `None`, if it failed to establish the defined connection.
|
||||
"""
|
||||
@ -144,15 +144,36 @@ class ControlClient(ABC):
|
||||
print("Disconnected from control server.")
|
||||
|
||||
|
||||
class TCPControlClient(ControlClient):
|
||||
"""Task pool control client that expects a TCP socket to be exposed by the control server."""
|
||||
|
||||
def __init__(self, host: str, port: Union[int, str], **conn_kwargs) -> None:
|
||||
"""In addition to what the base class does, `host` and `port` are expected as non-optional arguments."""
|
||||
self._host = host
|
||||
self._port = port
|
||||
super().__init__(**conn_kwargs)
|
||||
|
||||
async def _open_connection(self, **kwargs) -> ClientConnT:
|
||||
"""
|
||||
Wrapper around the `asyncio.open_connection` function.
|
||||
|
||||
Returns a tuple of `None` and `None`, if the connection can not be established;
|
||||
otherwise, the stream-reader and -writer tuple is returned.
|
||||
"""
|
||||
try:
|
||||
return await open_connection(self._host, self._port, **kwargs)
|
||||
except ConnectionError as e:
|
||||
print(str(e), file=sys.stderr)
|
||||
return None, None
|
||||
|
||||
|
||||
class UnixControlClient(ControlClient):
|
||||
"""Task pool control client that expects a unix socket to be exposed by the control server."""
|
||||
|
||||
def __init__(self, socket_path: PathT, **conn_kwargs) -> None:
|
||||
"""
|
||||
In addition to what the base class does, the `socket_path` is expected as a non-optional argument.
|
||||
|
||||
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
|
||||
"""
|
||||
"""In addition to what the base class does, the `socket_path` is expected as a non-optional argument."""
|
||||
from asyncio.streams import open_unix_connection
|
||||
self._open_unix_connection = open_unix_connection
|
||||
self._socket_path = Path(socket_path)
|
||||
super().__init__(**conn_kwargs)
|
||||
|
||||
@ -164,7 +185,7 @@ class UnixControlClient(ControlClient):
|
||||
otherwise, the stream-reader and -writer tuple is returned.
|
||||
"""
|
||||
try:
|
||||
return await open_unix_connection(self._socket_path, **kwargs)
|
||||
return await self._open_unix_connection(self._socket_path, **kwargs)
|
||||
except FileNotFoundError:
|
||||
print("No socket at", self._socket_path, file=sys.stderr)
|
||||
return None, None
|
||||
|
@ -37,10 +37,21 @@ class CLIENT_INFO:
|
||||
|
||||
class CMD:
|
||||
__slots__ = ()
|
||||
# Base commands:
|
||||
CMD = 'command'
|
||||
NAME = 'name'
|
||||
POOL_SIZE = 'pool-size'
|
||||
IS_LOCKED = 'is-locked'
|
||||
LOCK = 'lock'
|
||||
UNLOCK = 'unlock'
|
||||
NUM_RUNNING = 'num-running'
|
||||
NUM_CANCELLATIONS = 'num-cancellations'
|
||||
NUM_ENDED = 'num-ended'
|
||||
NUM_FINISHED = 'num-finished'
|
||||
IS_FULL = 'is-full'
|
||||
GET_GROUP_IDS = 'get-group-ids'
|
||||
|
||||
# Simple commands:
|
||||
START = 'start'
|
||||
STOP = 'stop'
|
||||
STOP_ALL = 'stop-all'
|
||||
|
@ -178,23 +178,26 @@ class BaseTaskPool:
|
||||
"""
|
||||
return self._enough_room.locked()
|
||||
|
||||
def get_task_group_ids(self, group_name: str) -> Set[int]:
|
||||
def get_group_ids(self, *group_names: str) -> Set[int]:
|
||||
"""
|
||||
Returns the set of IDs of all tasks in the specified group.
|
||||
Returns the set of IDs of all tasks in the specified groups.
|
||||
|
||||
Args:
|
||||
group_name: Must be a name of a task group that exists within the pool.
|
||||
*group_names: Each element must be a name of a task group that exists within the pool.
|
||||
|
||||
Returns:
|
||||
Set of integers representing the task IDs belonging to the specified group.
|
||||
Set of integers representing the task IDs belonging to the specified groups.
|
||||
|
||||
Raises:
|
||||
`InvalidGroupName` if no task group named `group_name` exists in the pool.
|
||||
`InvalidGroupName` if one of the specified `group_names` does not exist in the pool.
|
||||
"""
|
||||
try:
|
||||
return set(self._task_groups[group_name])
|
||||
except KeyError:
|
||||
raise exceptions.InvalidGroupName(f"No task group named {group_name} exists in this pool.")
|
||||
ids = set()
|
||||
for name in group_names:
|
||||
try:
|
||||
ids.update(self._task_groups[name])
|
||||
except KeyError:
|
||||
raise exceptions.InvalidGroupName(f"No task group named {name} exists in this pool.")
|
||||
return ids
|
||||
|
||||
def _check_start(self, *, awaitable: Awaitable = None, function: CoroutineFunc = None,
|
||||
ignore_lock: bool = False) -> None:
|
||||
|
@ -23,12 +23,12 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import AbstractServer
|
||||
from asyncio.exceptions import CancelledError
|
||||
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
|
||||
from asyncio.streams import StreamReader, StreamWriter, start_server
|
||||
from asyncio.tasks import Task, create_task
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from .client import ControlClient, UnixControlClient
|
||||
from .client import ControlClient, TCPControlClient, UnixControlClient
|
||||
from .pool import TaskPool, SimpleTaskPool
|
||||
from .session import ControlSession
|
||||
from .types import ConnectedCallbackT
|
||||
@ -132,16 +132,36 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
|
||||
return create_task(self._serve_forever())
|
||||
|
||||
|
||||
class TCPControlServer(ControlServer):
|
||||
"""Task pool control server class that exposes a TCP socket for control clients to connect to."""
|
||||
_client_class = TCPControlClient
|
||||
|
||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||
self._host = server_kwargs.pop('host')
|
||||
self._port = server_kwargs.pop('port')
|
||||
super().__init__(pool, **server_kwargs)
|
||||
|
||||
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||
server = await start_server(client_connected_cb, self._host, self._port, **kwargs)
|
||||
log.debug("Opened socket at %s:%s", self._host, self._port)
|
||||
return server
|
||||
|
||||
def _final_callback(self) -> None:
|
||||
log.debug("Closed socket at %s:%s", self._host, self._port)
|
||||
|
||||
|
||||
class UnixControlServer(ControlServer):
|
||||
"""Task pool control server class that exposes a unix socket for control clients to connect to."""
|
||||
_client_class = UnixControlClient
|
||||
|
||||
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
|
||||
from asyncio.streams import start_unix_server
|
||||
self._start_unix_server = start_unix_server
|
||||
self._socket_path = Path(server_kwargs.pop('path'))
|
||||
super().__init__(pool, **server_kwargs)
|
||||
|
||||
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
|
||||
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||
server = await self._start_unix_server(client_connected_cb, self._socket_path, **kwargs)
|
||||
log.debug("Opened socket '%s'", str(self._socket_path))
|
||||
return server
|
||||
|
||||
|
@ -23,7 +23,7 @@ import logging
|
||||
import json
|
||||
from argparse import ArgumentError, HelpFormatter
|
||||
from asyncio.streams import StreamReader, StreamWriter
|
||||
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
|
||||
|
||||
from .constants import CMD, SESSION_WRITER, SESSION_MSG_BYTES, CLIENT_INFO
|
||||
from .exceptions import HelpRequested, NotATaskPool, UnknownTaskPoolClass
|
||||
@ -108,19 +108,36 @@ class ControlSession:
|
||||
These include commands mapping to the following pool methods:
|
||||
- __str__
|
||||
- pool_size (get/set property)
|
||||
- is_locked
|
||||
- lock & unlock
|
||||
- num_running
|
||||
"""
|
||||
self._add_command(CMD.NAME, short_help=get_first_doc_line(self._pool.__class__.__str__))
|
||||
cls: Type[BaseTaskPool] = self._pool.__class__
|
||||
self._add_command(CMD.NAME, short_help=get_first_doc_line(cls.__str__))
|
||||
self._add_command(
|
||||
CMD.POOL_SIZE,
|
||||
short_help="Get/set the maximum number of tasks in the pool.",
|
||||
formatter_class=HelpFormatter
|
||||
).add_optional_num_argument(
|
||||
default=None,
|
||||
help=f"If passed a number: {get_first_doc_line(self._pool.__class__.pool_size.fset)} "
|
||||
f"If omitted: {get_first_doc_line(self._pool.__class__.pool_size.fget)}"
|
||||
help=f"If passed a number: {get_first_doc_line(cls.pool_size.fset)} "
|
||||
f"If omitted: {get_first_doc_line(cls.pool_size.fget)}"
|
||||
)
|
||||
self._add_command(CMD.IS_LOCKED, short_help=get_first_doc_line(cls.is_locked.fget))
|
||||
self._add_command(CMD.LOCK, short_help=get_first_doc_line(cls.lock))
|
||||
self._add_command(CMD.UNLOCK, short_help=get_first_doc_line(cls.unlock))
|
||||
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(cls.num_running.fget))
|
||||
self._add_command(CMD.NUM_CANCELLATIONS, short_help=get_first_doc_line(cls.num_cancellations.fget))
|
||||
self._add_command(CMD.NUM_ENDED, short_help=get_first_doc_line(cls.num_ended.fget))
|
||||
self._add_command(CMD.NUM_FINISHED, short_help=get_first_doc_line(cls.num_finished.fget))
|
||||
self._add_command(CMD.IS_FULL, short_help=get_first_doc_line(cls.is_full.fget))
|
||||
self._add_command(
|
||||
CMD.GET_GROUP_IDS, short_help=get_first_doc_line(cls.get_group_ids)
|
||||
).add_argument(
|
||||
'group_name',
|
||||
nargs='*',
|
||||
help="Must be a name of a task group that exists within the pool."
|
||||
)
|
||||
self._add_command(CMD.NUM_RUNNING, short_help=get_first_doc_line(self._pool.__class__.num_running.fget))
|
||||
|
||||
def _add_simple_commands(self) -> None:
|
||||
"""
|
||||
|
@ -20,10 +20,11 @@ Unittests for the `asyncio_taskpool.client` module.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest import IsolatedAsyncioTestCase, skipIf
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import client
|
||||
@ -171,6 +172,7 @@ class ControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
self.mock_print.assert_called_once_with("Disconnected from control server.")
|
||||
|
||||
|
||||
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
|
||||
class UnixControlClientTestCase(IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
85
tests/test_group_register.py
Normal file
85
tests/test_group_register.py
Normal file
@ -0,0 +1,85 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.group_register` module.
|
||||
"""
|
||||
|
||||
|
||||
from asyncio.locks import Lock
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import group_register
|
||||
|
||||
|
||||
FOO, BAR = 'foo', 'bar'
|
||||
|
||||
|
||||
class TaskGroupRegisterTestCase(IsolatedAsyncioTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.reg = group_register.TaskGroupRegister()
|
||||
|
||||
def test_init(self):
|
||||
ids = [FOO, BAR, 1, 2]
|
||||
reg = group_register.TaskGroupRegister(*ids)
|
||||
self.assertSetEqual(set(ids), reg._ids)
|
||||
self.assertIsInstance(reg._lock, Lock)
|
||||
|
||||
def test___contains__(self):
|
||||
self.reg._ids = {1, 2, 3}
|
||||
for i in self.reg._ids:
|
||||
self.assertTrue(i in self.reg)
|
||||
self.assertFalse(4 in self.reg)
|
||||
|
||||
@patch.object(group_register, 'iter', return_value=FOO)
|
||||
def test___iter__(self, mock_iter: MagicMock):
|
||||
self.assertEqual(FOO, self.reg.__iter__())
|
||||
mock_iter.assert_called_once_with(self.reg._ids)
|
||||
|
||||
def test___len__(self):
|
||||
self.reg._ids = [1, 2, 3, 4]
|
||||
self.assertEqual(4, len(self.reg))
|
||||
|
||||
def test_add(self):
|
||||
self.assertSetEqual(set(), self.reg._ids)
|
||||
self.assertIsNone(self.reg.add(123))
|
||||
self.assertSetEqual({123}, self.reg._ids)
|
||||
|
||||
def test_discard(self):
|
||||
self.reg._ids = {123}
|
||||
self.assertIsNone(self.reg.discard(0))
|
||||
self.assertIsNone(self.reg.discard(999))
|
||||
self.assertIsNone(self.reg.discard(123))
|
||||
self.assertSetEqual(set(), self.reg._ids)
|
||||
|
||||
async def test_acquire(self):
|
||||
self.assertFalse(self.reg._lock.locked())
|
||||
await self.reg.acquire()
|
||||
self.assertTrue(self.reg._lock.locked())
|
||||
|
||||
def test_release(self):
|
||||
self.reg._lock._locked = True
|
||||
self.assertTrue(self.reg._lock.locked())
|
||||
self.reg.release()
|
||||
self.assertFalse(self.reg._lock.locked())
|
||||
|
||||
async def test_contextmanager(self):
|
||||
self.assertFalse(self.reg._lock.locked())
|
||||
async with self.reg as nothing:
|
||||
self.assertIsNone(nothing)
|
||||
self.assertTrue(self.reg._lock.locked())
|
||||
self.assertFalse(self.reg._lock.locked())
|
@ -163,12 +163,12 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
||||
def test_is_full(self):
|
||||
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
|
||||
|
||||
def test_get_task_group_ids(self):
|
||||
def test_get_group_ids(self):
|
||||
group_name, ids = 'abcdef', [1, 2, 3]
|
||||
self.task_pool._task_groups[group_name] = MagicMock(__iter__=lambda _: iter(ids))
|
||||
self.assertEqual(set(ids), self.task_pool.get_task_group_ids(group_name))
|
||||
self.assertEqual(set(ids), self.task_pool.get_group_ids(group_name))
|
||||
with self.assertRaises(exceptions.InvalidGroupName):
|
||||
self.task_pool.get_task_group_ids('something else')
|
||||
self.task_pool.get_group_ids(group_name, 'something else')
|
||||
|
||||
async def test__check_start(self):
|
||||
self.task_pool._closed = True
|
||||
|
43
tests/test_queue_context.py
Normal file
43
tests/test_queue_context.py
Normal file
@ -0,0 +1,43 @@
|
||||
__author__ = "Daniil Fajnberg"
|
||||
__copyright__ = "Copyright © 2022 Daniil Fajnberg"
|
||||
__license__ = """GNU LGPLv3.0
|
||||
|
||||
This file is part of asyncio-taskpool.
|
||||
|
||||
asyncio-taskpool is free software: you can redistribute it and/or modify it under the terms of
|
||||
version 3.0 of the GNU Lesser General Public License as published by the Free Software Foundation.
|
||||
|
||||
asyncio-taskpool is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
||||
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
See the GNU Lesser General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Lesser General Public License along with asyncio-taskpool.
|
||||
If not, see <https://www.gnu.org/licenses/>."""
|
||||
|
||||
__doc__ = """
|
||||
Unittests for the `asyncio_taskpool.queue_context` module.
|
||||
"""
|
||||
|
||||
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from asyncio_taskpool.queue_context import Queue
|
||||
|
||||
|
||||
class QueueTestCase(IsolatedAsyncioTestCase):
|
||||
def test_item_processed(self):
|
||||
queue = Queue()
|
||||
queue._unfinished_tasks = 1000
|
||||
queue.item_processed()
|
||||
self.assertEqual(999, queue._unfinished_tasks)
|
||||
|
||||
@patch.object(Queue, 'item_processed')
|
||||
async def test_contextmanager(self, mock_item_processed: MagicMock):
|
||||
queue = Queue()
|
||||
item = 'foo'
|
||||
queue.put_nowait(item)
|
||||
async with queue as item_from_queue:
|
||||
self.assertEqual(item, item_from_queue)
|
||||
mock_item_processed.assert_not_called()
|
||||
mock_item_processed.assert_called_once_with()
|
@ -21,8 +21,9 @@ Unittests for the `asyncio_taskpool.server` module.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest import IsolatedAsyncioTestCase, skipIf
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from asyncio_taskpool import server
|
||||
@ -119,6 +120,7 @@ class ControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
mock_create_task.assert_called_once_with(mock_awaitable)
|
||||
|
||||
|
||||
@skipIf(os.name == 'nt', "No Unix sockets on Windows :(")
|
||||
class UnixControlServerTestCase(IsolatedAsyncioTestCase):
|
||||
log_lvl: int
|
||||
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
## Minimal example for `SimpleTaskPool`
|
||||
|
||||
With a `SimpleTaskPool` the function to execute as well as the arguments with which to execute it must be defined during its initialization (and they cannot be changed later). The only control you have after initialization is how many of such tasks are being run.
|
||||
|
||||
The minimum required setup is a "worker" coroutine function that can do something asynchronously, and a main coroutine function that sets up the `SimpleTaskPool`, starts/stops the tasks as desired, and eventually awaits them all.
|
||||
|
||||
The following demo code enables full log output first for additional clarity. It is complete and should work as is.
|
||||
@ -32,12 +34,12 @@ async def work(n: int) -> None:
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
|
||||
pool = SimpleTaskPool(work, args=(5,)) # initializes the pool; no work is being done yet
|
||||
await pool.start(3) # launches work tasks 0, 1, and 2
|
||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||
await pool.start() # launches work task 3
|
||||
await asyncio.sleep(1.5) # lets the tasks work for a bit
|
||||
pool.stop(2) # cancels tasks 3 and 2
|
||||
pool.stop(2) # cancels tasks 3 and 2 (LIFO order)
|
||||
pool.lock() # required for the last line
|
||||
await pool.gather_and_close() # awaits all tasks, then flushes the pool
|
||||
|
||||
@ -114,19 +116,19 @@ async def other_work(a: int, b: int) -> None:
|
||||
async def main() -> None:
|
||||
# Initialize a new task pool instance and limit its size to 3 tasks.
|
||||
pool = TaskPool(3)
|
||||
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same positional arguments).
|
||||
# Queue up two tasks (IDs 0 and 1) to run concurrently (with the same keyword-arguments).
|
||||
print("> Called `apply`")
|
||||
await pool.apply(work, kwargs={'start': 100, 'stop': 200, 'step': 10}, num=2)
|
||||
# Let the tasks work for a bit.
|
||||
await asyncio.sleep(1.5)
|
||||
# Now, let us enqueue four more tasks (which will receive IDs 2, 3, 4, and 5), each created with different
|
||||
# positional arguments by using `starmap`, but have **no more than two of those** run concurrently.
|
||||
# positional arguments by using `starmap`, but we want no more than two of those to run concurrently.
|
||||
# Since we set our pool size to 3, and already have two tasks working within the pool,
|
||||
# only the first one of these will start immediately (and receive ID 2).
|
||||
# The second one will start (with ID 3), only once there is room in the pool,
|
||||
# which -- in this example -- will be the case after ID 2 ends.
|
||||
# Once there is room in the pool again, the third and fourth will each start (with IDs 4 and 5)
|
||||
# **only** once there is room in the pool **and** no more than one other task of these new ones is running.
|
||||
# only once there is room in the pool and no more than one other task of these new ones is running.
|
||||
args_list = [(0, 10), (10, 20), (20, 30), (30, 40)]
|
||||
await pool.starmap(other_work, args_list, group_size=2)
|
||||
print("> Called `starmap`")
|
||||
|
@ -23,7 +23,7 @@ Use the main CLI client to interface at the socket.
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from asyncio_taskpool import SimpleTaskPool, UnixControlServer
|
||||
from asyncio_taskpool import SimpleTaskPool, TCPControlServer
|
||||
from asyncio_taskpool.constants import PACKAGE_NAME
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ logging.getLogger(PACKAGE_NAME).addHandler(logging.StreamHandler())
|
||||
async def work(item: int) -> None:
|
||||
"""The non-blocking sleep simulates something like an I/O operation that can be done asynchronously."""
|
||||
await asyncio.sleep(1)
|
||||
print("worked on", item)
|
||||
print("worked on", item, flush=True)
|
||||
|
||||
|
||||
async def worker(q: asyncio.Queue) -> None:
|
||||
"""Simulates doing asynchronous work that takes a little bit of time to finish."""
|
||||
"""Simulates doing asynchronous work that takes a bit of time to finish."""
|
||||
# We only want the worker to stop, when its task is cancelled; therefore we start an infinite loop.
|
||||
while True:
|
||||
# We want to block here, until we can get the next item from the queue.
|
||||
@ -67,7 +67,7 @@ async def main() -> None:
|
||||
q.put_nowait(item)
|
||||
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
|
||||
await pool.start(3) # launches three worker tasks
|
||||
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
|
||||
control_server_task = await TCPControlServer(pool, host='127.0.0.1', port=9999).serve_forever()
|
||||
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
|
||||
await q.join()
|
||||
# Since we don't need any "work" done anymore, we can lock our control server by cancelling the task.
|
||||
@ -76,7 +76,7 @@ async def main() -> None:
|
||||
# we can now safely cancel their tasks.
|
||||
pool.lock()
|
||||
pool.stop_all()
|
||||
# Finally we allow for all tasks to do do their cleanup, if they need to do any, upon being cancelled.
|
||||
# Finally, we allow for all tasks to do their cleanup (as if they need to do any) upon being cancelled.
|
||||
# We block until they all return or raise an exception, but since we are not interested in any of their exceptions,
|
||||
# we just silently collect their exceptions along with their return values.
|
||||
await pool.gather_and_close(return_exceptions=True)
|
||||
|
Reference in New Issue
Block a user