generated from daniil-berg/boilerplate-py
Add until_closed
method to pools
This commit is contained in:
parent
a4ecf39157
commit
dae883446a
@ -22,7 +22,7 @@ copyright = '2022 Daniil Fajnberg'
|
|||||||
author = 'Daniil Fajnberg'
|
author = 'Daniil Fajnberg'
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
release = '1.0.2'
|
release = '1.1.0'
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = asyncio-taskpool
|
name = asyncio-taskpool
|
||||||
version = 1.0.3
|
version = 1.1.0
|
||||||
author = Daniil Fajnberg
|
author = Daniil Fajnberg
|
||||||
author_email = mail@daniil.fajnberg.de
|
author_email = mail@daniil.fajnberg.de
|
||||||
description = Dynamically manage pools of asyncio tasks
|
description = Dynamically manage pools of asyncio tasks
|
||||||
|
@ -31,7 +31,7 @@ import logging
|
|||||||
import warnings
|
import warnings
|
||||||
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
from asyncio.coroutines import iscoroutine, iscoroutinefunction
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Event, Semaphore
|
||||||
from asyncio.tasks import Task, create_task, gather
|
from asyncio.tasks import Task, create_task, gather
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from math import inf
|
from math import inf
|
||||||
@ -72,7 +72,7 @@ class BaseTaskPool:
|
|||||||
|
|
||||||
# Initialize flags; immutably set the name.
|
# Initialize flags; immutably set the name.
|
||||||
self._locked: bool = False
|
self._locked: bool = False
|
||||||
self._closed: bool = False
|
self._closed: Event = Event()
|
||||||
self._name: str = name
|
self._name: str = name
|
||||||
|
|
||||||
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
# The following three dictionaries are the actual containers of the tasks controlled by the pool.
|
||||||
@ -221,7 +221,7 @@ class BaseTaskPool:
|
|||||||
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
raise exceptions.NotCoroutine(f"Not awaitable: {awaitable}")
|
||||||
if function and not iscoroutinefunction(function):
|
if function and not iscoroutinefunction(function):
|
||||||
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
raise exceptions.NotCoroutine(f"Not a coroutine function: {function}")
|
||||||
if self._closed:
|
if self._closed.is_set():
|
||||||
raise exceptions.PoolIsClosed("You must use another pool")
|
raise exceptions.PoolIsClosed("You must use another pool")
|
||||||
if self._locked and not ignore_lock:
|
if self._locked and not ignore_lock:
|
||||||
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
raise exceptions.PoolIsLocked("Cannot start new tasks")
|
||||||
@ -550,9 +550,10 @@ class BaseTaskPool:
|
|||||||
self._tasks_ended.clear()
|
self._tasks_ended.clear()
|
||||||
self._tasks_cancelled.clear()
|
self._tasks_cancelled.clear()
|
||||||
self._tasks_running.clear()
|
self._tasks_running.clear()
|
||||||
self._closed = True
|
self._closed.set()
|
||||||
# TODO: Turn the `_closed` attribute into an `Event` and add something like a `until_closed` method that will
|
|
||||||
# await it to allow blocking until a closing command comes from a server.
|
async def until_closed(self) -> bool:
|
||||||
|
return await self._closed.wait()
|
||||||
|
|
||||||
|
|
||||||
class TaskPool(BaseTaskPool):
|
class TaskPool(BaseTaskPool):
|
||||||
|
@ -19,7 +19,7 @@ Unittests for the `asyncio_taskpool.pool` module.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from asyncio.locks import Semaphore
|
from asyncio.locks import Event, Semaphore
|
||||||
from unittest import IsolatedAsyncioTestCase
|
from unittest import IsolatedAsyncioTestCase
|
||||||
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
from unittest.mock import PropertyMock, MagicMock, AsyncMock, patch, call
|
||||||
from typing import Type
|
from typing import Type
|
||||||
@ -83,7 +83,8 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertEqual(0, self.task_pool._num_started)
|
self.assertEqual(0, self.task_pool._num_started)
|
||||||
|
|
||||||
self.assertFalse(self.task_pool._locked)
|
self.assertFalse(self.task_pool._locked)
|
||||||
self.assertFalse(self.task_pool._closed)
|
self.assertIsInstance(self.task_pool._closed, Event)
|
||||||
|
self.assertFalse(self.task_pool._closed.is_set())
|
||||||
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
self.assertEqual(self.TEST_POOL_NAME, self.task_pool._name)
|
||||||
|
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
@ -162,7 +163,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool.get_group_ids(group_name, 'something else')
|
self.task_pool.get_group_ids(group_name, 'something else')
|
||||||
|
|
||||||
async def test__check_start(self):
|
async def test__check_start(self):
|
||||||
self.task_pool._closed = True
|
self.task_pool._closed.set()
|
||||||
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
mock_coroutine, mock_coroutine_function = AsyncMock()(), AsyncMock()
|
||||||
try:
|
try:
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
@ -175,7 +176,7 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
self.task_pool._check_start(awaitable=None, function=mock_coroutine)
|
||||||
with self.assertRaises(exceptions.PoolIsClosed):
|
with self.assertRaises(exceptions.PoolIsClosed):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None)
|
||||||
self.task_pool._closed = False
|
self.task_pool._closed.clear()
|
||||||
self.task_pool._locked = True
|
self.task_pool._locked = True
|
||||||
with self.assertRaises(exceptions.PoolIsLocked):
|
with self.assertRaises(exceptions.PoolIsLocked):
|
||||||
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
self.task_pool._check_start(awaitable=mock_coroutine, function=None, ignore_lock=False)
|
||||||
@ -461,7 +462,13 @@ class BaseTaskPoolTestCase(CommonTestCase):
|
|||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_ended)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_cancelled)
|
||||||
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
self.assertDictEqual(EMPTY_DICT, self.task_pool._tasks_running)
|
||||||
self.assertTrue(self.task_pool._closed)
|
self.assertTrue(self.task_pool._closed.is_set())
|
||||||
|
|
||||||
|
async def test_until_closed(self):
|
||||||
|
self.task_pool._closed = MagicMock(wait=AsyncMock(return_value=FOO))
|
||||||
|
output = await self.task_pool.until_closed()
|
||||||
|
self.assertEqual(FOO, output)
|
||||||
|
self.task_pool._closed.wait.assert_awaited_once_with()
|
||||||
|
|
||||||
|
|
||||||
class TaskPoolTestCase(CommonTestCase):
|
class TaskPoolTestCase(CommonTestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user