implemented working pool size limit; adjusted tests and examples; small renaming

This commit is contained in:
Daniil Fajnberg 2022-02-06 13:08:39 +01:00
parent a68e61dfa7
commit 2f0b08edf0
6 changed files with 98 additions and 125 deletions

View File

@ -2,13 +2,13 @@ import logging
from asyncio import gather
from asyncio.coroutines import iscoroutinefunction
from asyncio.exceptions import CancelledError
from asyncio.locks import Event
from asyncio.locks import Event, Semaphore
from asyncio.tasks import Task, create_task
from math import inf
from typing import Any, Awaitable, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from . import exceptions
from .types import ArgsT, KwArgsT, CoroutineFunc, FinalCallbackT, CancelCallbackT
from .types import ArgsT, KwArgsT, CoroutineFunc, EndCallbackT, CancelCallbackT
log = logging.getLogger(__name__)
@ -26,7 +26,8 @@ class BaseTaskPool:
def __init__(self, pool_size: int = inf, name: str = None) -> None:
"""Initializes the necessary internal attributes and adds the new pool to the general pools list."""
self.pool_size: int = pool_size
self._enough_room: Semaphore = Semaphore()
self.pool_size = pool_size
self._open: bool = True
self._counter: int = 0
self._running: Dict[int, Task] = {}
@ -37,13 +38,22 @@ class BaseTaskPool:
self._name: str = name
self._all_tasks_known_flag: Event = Event()
self._all_tasks_known_flag.set()
self._more_allowed_flag: Event = Event()
self._check_more_allowed()
log.debug("%s initialized", str(self))
def __str__(self) -> str:
return f'{self.__class__.__name__}-{self._name or self._idx}'
@property
def pool_size(self) -> int:
return self._pool_size
@pool_size.setter
def pool_size(self, value: int) -> None:
if value < 0:
raise ValueError("Pool size can not be less than 0")
self._enough_room._value = value
self._pool_size = value
@property
def is_open(self) -> bool:
"""Returns `True` if more the pool has not been closed yet."""
@ -53,7 +63,7 @@ class BaseTaskPool:
def num_running(self) -> int:
"""
Returns the number of tasks in the pool that are (at that moment) still running.
At the moment a task's final callback function is fired, it is no longer considered to be running.
At the moment a task's `end_callback` is fired, it is no longer considered to be running.
"""
return len(self._running)
@ -61,7 +71,7 @@ class BaseTaskPool:
def num_cancelled(self) -> int:
"""
Returns the number of tasks in the pool that have been cancelled through the pool (up until that moment).
At the moment a task's cancel callback function is fired, it is considered cancelled and no longer running.
At the moment a task's `cancel_callback` is fired, it is considered cancelled and no longer running.
"""
return len(self._cancelled)
@ -69,9 +79,9 @@ class BaseTaskPool:
def num_ended(self) -> int:
"""
Returns the number of tasks started through the pool that have stopped running (up until that moment).
At the moment a task's final callback function is fired, it is considered ended.
When a task is cancelled, it is not immediately considered ended; only after its cancel callback function has
returned, does it then actually end.
At the moment a task's `end_callback` is fired, it is considered ended.
When a task is cancelled, it is not immediately considered ended; only after its `cancel_callback` has returned,
does it then actually end.
"""
return len(self._ended)
@ -88,18 +98,9 @@ class BaseTaskPool:
Returns `False` only if (at that moment) the number of running tasks is below the pool's specified size.
When the pool is full, any call to start a new task within it will block.
"""
return not self._more_allowed_flag.is_set()
def _check_more_allowed(self) -> None:
"""
Sets or clears the internal event flag signalling whether or not the pool is full (i.e. whether more tasks can
be started), if the current state of the pool demands it.
"""
if self.is_full and self.num_running < self.pool_size:
self._more_allowed_flag.set()
elif not self.is_full and self.num_running >= self.pool_size:
self._more_allowed_flag.clear()
return self._enough_room.locked()
# TODO: Consider adding task group names
def _task_name(self, task_id: int) -> str:
"""Returns a standardized name for a task with a specific `task_id`."""
return f'{self}_Task-{task_id}'
@ -110,20 +111,20 @@ class BaseTaskPool:
assert task is not None
self._cancelled[task_id] = task
self._ending += 1
await _execute_function(custom_callback, args=(task_id, ))
log.debug("Cancelled %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, ))
async def _end_task(self, task_id: int, custom_callback: FinalCallbackT = None) -> None:
async def _end_task(self, task_id: int, custom_callback: EndCallbackT = None) -> None:
task = self._running.pop(task_id, None)
if task is None:
task = self._cancelled[task_id]
self._ending -= 1
self._ended[task_id] = task
await _execute_function(custom_callback, args=(task_id, ))
self._check_more_allowed()
self._enough_room.release()
log.info("Ended %s", self._task_name(task_id))
await _execute_function(custom_callback, args=(task_id, ))
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, final_callback: FinalCallbackT = None,
async def _task_wrapper(self, awaitable: Awaitable, task_id: int, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> Any:
log.info("Started %s", self._task_name(task_id))
try:
@ -131,20 +132,23 @@ class BaseTaskPool:
except CancelledError:
await self._cancel_task(task_id, custom_callback=cancel_callback)
finally:
await self._end_task(task_id, custom_callback=final_callback)
await self._end_task(task_id, custom_callback=end_callback)
def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, final_callback: FinalCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int:
async def _start_task(self, awaitable: Awaitable, ignore_closed: bool = False, end_callback: EndCallbackT = None,
cancel_callback: CancelCallbackT = None) -> int:
if not (self.is_open or ignore_closed):
raise exceptions.PoolIsClosed("Cannot start new tasks")
# TODO: Implement this (and the dependent user-facing methods) as async to wait for room in the pool
task_id = self._counter
self._counter += 1
self._running[task_id] = create_task(
self._task_wrapper(awaitable, task_id, final_callback, cancel_callback),
name=self._task_name(task_id)
)
self._check_more_allowed()
await self._enough_room.acquire()
try:
task_id = self._counter
self._counter += 1
self._running[task_id] = create_task(
self._task_wrapper(awaitable, task_id, end_callback, cancel_callback),
name=self._task_name(task_id)
)
except Exception as e:
self._enough_room.release()
raise e
return task_id
def _cancel_one(self, task_id: int, msg: str = None) -> None:
@ -180,15 +184,15 @@ class BaseTaskPool:
class TaskPool(BaseTaskPool):
def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
async def _apply_one(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> int:
if kwargs is None:
kwargs = {}
return self._start_task(func(*args, **kwargs), final_callback=final_callback, cancel_callback=cancel_callback)
return await self._start_task(func(*args, **kwargs), end_callback=end_callback, cancel_callback=cancel_callback)
def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
return tuple(self._apply_one(func, args, kwargs, final_callback, cancel_callback) for _ in range(num))
async def apply(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None, num: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> Tuple[int]:
return tuple(await self._apply_one(func, args, kwargs, end_callback, cancel_callback) for _ in range(num))
@staticmethod
def _get_next_coroutine(func: CoroutineFunc, args_iter: Iterator[Any], arg_stars: int = 0) -> Optional[Awaitable]:
@ -204,54 +208,54 @@ class TaskPool(BaseTaskPool):
return func(**arg)
raise ValueError
def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
async def _map(self, func: CoroutineFunc, args_iter: ArgsT, arg_stars: int = 0, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
if self._all_tasks_known_flag.is_set():
self._all_tasks_known_flag.clear()
args_iter = iter(args_iter)
def _start_next_coroutine() -> bool:
async def _start_next_coroutine() -> bool:
cor = self._get_next_coroutine(func, args_iter, arg_stars)
if cor is None:
self._all_tasks_known_flag.set()
return True
self._start_task(cor, ignore_closed=True, final_callback=_start_next, cancel_callback=cancel_callback)
await self._start_task(cor, ignore_closed=True, end_callback=_start_next, cancel_callback=cancel_callback)
return False
async def _start_next(task_id: int) -> None:
await _execute_function(final_callback, args=(task_id, ))
_start_next_coroutine()
await _start_next_coroutine()
await _execute_function(end_callback, args=(task_id, ))
for _ in range(num_tasks):
reached_end = _start_next_coroutine()
reached_end = await _start_next_coroutine()
if reached_end:
break
def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback)
async def map(self, func: CoroutineFunc, args_iter: ArgsT, num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
await self._map(func, args_iter, arg_stars=0, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback)
async def starmap(self, func: CoroutineFunc, args_iter: Iterable[ArgsT], num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
await self._map(func, args_iter, arg_stars=1, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
final_callback=final_callback, cancel_callback=cancel_callback)
async def doublestarmap(self, func: CoroutineFunc, kwargs_iter: Iterable[KwArgsT], num_tasks: int = 1,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None) -> None:
await self._map(func, kwargs_iter, arg_stars=2, num_tasks=num_tasks,
end_callback=end_callback, cancel_callback=cancel_callback)
class SimpleTaskPool(BaseTaskPool):
def __init__(self, func: CoroutineFunc, args: ArgsT = (), kwargs: KwArgsT = None,
final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None,
end_callback: EndCallbackT = None, cancel_callback: CancelCallbackT = None,
name: str = None) -> None:
self._func: CoroutineFunc = func
self._args: ArgsT = args
self._kwargs: KwArgsT = kwargs if kwargs is not None else {}
self._final_callback: FinalCallbackT = final_callback
self._end_callback: EndCallbackT = end_callback
self._cancel_callback: CancelCallbackT = cancel_callback
super().__init__(name=name)
@ -263,12 +267,12 @@ class SimpleTaskPool(BaseTaskPool):
def size(self) -> int:
return self.num_running
def _start_one(self) -> int:
return self._start_task(self._func(*self._args, **self._kwargs),
final_callback=self._final_callback, cancel_callback=self._cancel_callback)
async def _start_one(self) -> int:
return await self._start_task(self._func(*self._args, **self._kwargs),
end_callback=self._end_callback, cancel_callback=self._cancel_callback)
def start(self, num: int = 1) -> List[int]:
return [self._start_one() for _ in range(num)]
async def start(self, num: int = 1) -> List[int]:
return [await self._start_one() for _ in range(num)]
def stop(self, num: int = 1) -> List[int]:
num = min(num, self.size)

View File

@ -45,11 +45,11 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
self._server_kwargs = server_kwargs
self._server: Optional[AbstractServer] = None
def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
async def _start_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None:
num = 1
log.debug("%s requests starting %s %s", self.client_class.__name__, num, tasks_str(num))
writer.write(str(self._pool.start(num)).encode())
writer.write(str(await self._pool.start(num)).encode())
def _stop_tasks(self, writer: StreamWriter, num: int = None) -> None:
if num is None:
@ -78,7 +78,7 @@ class ControlServer(ABC): # TODO: Implement interface for normal TaskPool insta
break
cmd, arg = get_cmd_arg(msg)
if cmd == constants.CMD_START:
self._start_tasks(writer, arg)
await self._start_tasks(writer, arg)
elif cmd == constants.CMD_STOP:
self._stop_tasks(writer, arg)
elif cmd == constants.CMD_STOP_ALL:

View File

@ -5,7 +5,7 @@ from typing import Any, Awaitable, Callable, Iterable, Mapping, Tuple, Union
ArgsT = Iterable[Any]
KwArgsT = Mapping[str, Any]
CoroutineFunc = Callable[[...], Awaitable[Any]]
FinalCallbackT = Callable
EndCallbackT = Callable
CancelCallbackT = Callable
ClientConnT = Union[Tuple[StreamReader, StreamWriter], Tuple[None, None]]

View File

@ -1,6 +1,6 @@
import asyncio
from unittest import TestCase
from unittest.mock import MagicMock, PropertyMock, patch, call
from unittest.mock import MagicMock, PropertyMock, patch
from asyncio_taskpool import pool
@ -14,10 +14,10 @@ class BaseTaskPoolTestCase(TestCase):
# These three methods are called during initialization, so we mock them by default during setup
self._add_pool_patcher = patch.object(pool.BaseTaskPool, '_add_pool')
self._check_more_allowed_patcher = patch.object(pool.BaseTaskPool, '_check_more_allowed')
self.pool_size_patcher = patch.object(pool.BaseTaskPool, 'pool_size', new_callable=PropertyMock)
self.__str___patcher = patch.object(pool.BaseTaskPool, '__str__')
self.mock__add_pool = self._add_pool_patcher.start()
self.mock__check_more_allowed = self._check_more_allowed_patcher.start()
self.mock_pool_size = self.pool_size_patcher.start()
self.mock___str__ = self.__str___patcher.start()
self.mock__add_pool.return_value = self.mock_idx = 123
self.mock___str__.return_value = self.mock_str = 'foobar'
@ -29,7 +29,7 @@ class BaseTaskPoolTestCase(TestCase):
def tearDown(self) -> None:
setattr(pool.TaskPool, '_pools', self._pools)
self._add_pool_patcher.stop()
self._check_more_allowed_patcher.stop()
self.pool_size_patcher.stop()
self.__str___patcher.stop()
def test__add_pool(self):
@ -40,7 +40,7 @@ class BaseTaskPoolTestCase(TestCase):
self.assertListEqual([self.task_pool], getattr(pool.TaskPool, '_pools'))
def test_init(self):
self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
self.assertIsInstance(self.task_pool._enough_room, asyncio.locks.Semaphore)
self.assertTrue(self.task_pool._open)
self.assertEqual(0, self.task_pool._counter)
self.assertDictEqual(EMPTY_DICT, self.task_pool._running)
@ -51,9 +51,8 @@ class BaseTaskPoolTestCase(TestCase):
self.assertEqual(self.test_pool_name, self.task_pool._name)
self.assertIsInstance(self.task_pool._all_tasks_known_flag, asyncio.locks.Event)
self.assertTrue(self.task_pool._all_tasks_known_flag.is_set())
self.assertIsInstance(self.task_pool._more_allowed_flag, asyncio.locks.Event)
self.mock__add_pool.assert_called_once_with(self.task_pool)
self.mock__check_more_allowed.assert_called_once_with()
self.mock_pool_size.assert_called_once_with(self.test_pool_size)
self.mock___str__.assert_called_once_with()
def test___str__(self):
@ -64,6 +63,17 @@ class BaseTaskPoolTestCase(TestCase):
expected_str = f'{pool.BaseTaskPool.__name__}-{self.task_pool._idx}'
self.assertEqual(expected_str, str(self.task_pool))
def test_pool_size(self):
self.pool_size_patcher.stop()
self.task_pool._pool_size = self.test_pool_size
self.assertEqual(self.test_pool_size, self.task_pool.pool_size)
with self.assertRaises(ValueError):
self.task_pool.pool_size = -1
self.task_pool.pool_size = new_size = 69
self.assertEqual(new_size, self.task_pool._pool_size)
def test_is_open(self):
self.task_pool._open = foo = 'foo'
self.assertEqual(foo, self.task_pool.is_open)
@ -91,48 +101,7 @@ class BaseTaskPoolTestCase(TestCase):
mock_num_ended.assert_called_once_with()
def test_is_full(self):
self.assertEqual(not self.task_pool._more_allowed_flag.is_set(), self.task_pool.is_full)
@patch.object(pool.BaseTaskPool, 'num_running', new_callable=PropertyMock)
@patch.object(pool.BaseTaskPool, 'is_full', new_callable=PropertyMock)
def test__check_more_allowed(self, mock_is_full: MagicMock, mock_num_running: MagicMock):
def reset_mocks():
mock_is_full.reset_mock()
mock_num_running.reset_mock()
self._check_more_allowed_patcher.stop()
# Just reaching limit, we expect flag to become unset:
mock_is_full.return_value = False
mock_num_running.return_value = 420
self.task_pool._more_allowed_flag.clear()
self.task_pool._check_more_allowed()
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
reset_mocks()
# Already at limit, we expect nothing to change:
mock_is_full.return_value = True
self.task_pool._check_more_allowed()
self.assertFalse(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
reset_mocks()
# Just finished a task, we expect flag to become set:
mock_num_running.return_value = 419
self.task_pool._check_more_allowed()
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_called_once_with()
mock_num_running.assert_called_once_with()
reset_mocks()
# In this state we expect the flag to remain unchanged change:
mock_is_full.return_value = False
self.task_pool._check_more_allowed()
self.assertTrue(self.task_pool._more_allowed_flag.is_set())
mock_is_full.assert_has_calls([call(), call()])
mock_num_running.assert_called_once_with()
self.assertEqual(self.task_pool._enough_room.locked(), self.task_pool.is_full)
def test__task_name(self):
i = 123

View File

@ -33,9 +33,9 @@ async def work(n: int) -> None:
async def main() -> None:
pool = SimpleTaskPool(work, (5,)) # initializes the pool; no work is being done yet
pool.start(3) # launches work tasks 0, 1, and 2
await pool.start(3) # launches work tasks 0, 1, and 2
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.start() # launches work task 3
await pool.start() # launches work task 3
await asyncio.sleep(1.5) # lets the tasks work for a bit
pool.stop(2) # cancels tasks 3 and 2
pool.close() # required for the last line

View File

@ -44,7 +44,7 @@ async def main() -> None:
for item in range(100):
q.put_nowait(item)
pool = SimpleTaskPool(worker, (q,)) # initializes the pool
pool.start(3) # launches three worker tasks
await pool.start(3) # launches three worker tasks
control_server_task = await UnixControlServer(pool, path='/tmp/py_asyncio_taskpool.sock').serve_forever()
# We block until `.task_done()` has been called once by our workers for every item placed into the queue.
await q.join()