# webutils-df/src/webutils_df/util.py
import logging
import asyncio
from functools import wraps
from inspect import signature
from math import inf
from timeit import default_timer
from typing import Callable, Awaitable, Dict, Tuple, Sequence, Any, Type, Union, TypeVar
from aiohttp.client import ClientSession
# Name under which this package's logger is registered; consumers can attach
# handlers or set levels via `logging.getLogger('webutils')`.
LOGGER_NAME = 'webutils'
logger = logging.getLogger(LOGGER_NAME)
# Type variable standing in for a decorated async function, so the decorators
# below advertise returning the same callable type they receive.
AsyncFunction = TypeVar('AsyncFunction')
# Expected signature of an `attempt(...)` callback, called with:
# (decorated function, caught exception, number of failed attempts so far,
#  seconds between attempts, positional args tuple, keyword args dict).
AttemptCallbackT = Callable[[AsyncFunction, Exception, int, float, tuple, dict], Awaitable[None]]
def _get_param_idx_and_default(function: Callable, param_name: str) -> Tuple[int, Any]:
params = signature(function).parameters
return list(params.keys()).index(param_name), params[param_name].default
def in_async_session(_func: AsyncFunction = None, *,
                     session_kwargs: Dict[str, Any] = None, session_param_name: str = 'session') -> AsyncFunction:
    """
    Useful decorator for any async function that uses the `aiohttp.ClientSession` to make requests.
    Using this decorator allows the decorated function to have an optional session parameter,
    without the need to ensure proper initialization and closing of a session within the function body itself.
    The wrapper has no effect, if the default value is passed to the session parameter during the function call
    (or the session argument is omitted entirely), otherwise it initializes a session, passes it into the
    function and ensures that it is closed in the end.
    The wrapper distinguishes between positional and keyword arguments. This means that if the function is called by
    passing the default value as a positional argument, the temporary session is injected in its place as a positional
    argument also; if the default is passed as a keyword argument, the temporary session is injected as such instead.

    Args:
        _func:
            If this decorator is used *without any* arguments, this will always be the decorated function itself.
            This is a trick to allow the decorator to be used with as well as without arguments, i.e. in the form
            `@in_async_session` or `@in_async_session(...)`.
        session_kwargs (optional):
            If passed a dictionary, it will be unpacked and passed as keyword arguments into the `ClientSession`
            constructor, if and only if the decorator actually handles session initialization/closing,
            i.e. only when the function is called **without** passing a session object into it.
        session_param_name (optional):
            The name of the decorated function's parameter that should be passed the session object as an argument.
            In case the decorated function's session parameter is named anything other than "session", that name should
            be provided here.
    """
    # Avoid the mutable-default-argument pitfall; a fresh dict per decoration.
    if session_kwargs is None:
        session_kwargs = {}

    def decorator(function: AsyncFunction) -> AsyncFunction:
        # Using `functools.wraps` to preserve information about the actual function being decorated
        # More details: https://docs.python.org/3/library/functools.html#functools.wraps
        @wraps(function)
        async def wrapper(*args, **kwargs) -> Any:
            """The actual function wrapper that may perform the session initialization and closing."""
            temp_session = None
            session_param_idx, session_param_default = _get_param_idx_and_default(function, session_param_name)
            try:
                session, positional = args[session_param_idx], True
            except IndexError:
                # No positional session argument was passed; fall back to the keyword argument.
                # Fix: default to the parameter's own declared default (instead of `None`), so that
                # omitting the session argument entirely also triggers injection even when the
                # function's session parameter has a non-`None` default.
                session, positional = kwargs.get(session_param_name, session_param_default), False
            if session == session_param_default:
                temp_session = ClientSession(**session_kwargs)
                if positional:
                    logger.debug(f"Injecting temporary `ClientSession` as positional argument {session_param_idx} "
                                 f"into `{function.__name__}`")
                    args = list(args)
                    args[session_param_idx] = temp_session
                else:
                    logger.debug(f"Injecting temporary `ClientSession` as keyword argument `{session_param_name}` "
                                 f"into `{function.__name__}`")
                    kwargs[session_param_name] = temp_session
            try:
                return await function(*args, **kwargs)
            finally:
                # Only close a session this wrapper created itself; a caller-provided
                # session remains the caller's responsibility.
                if temp_session is not None:
                    await temp_session.close()
                    logger.debug("Temporary `ClientSession` closed")
        return wrapper
    return decorator if _func is None else decorator(_func)
def attempt(_func: AsyncFunction = None, *,
            exception: Union[Type[Exception], Sequence[Type[Exception]]] = Exception,
            max_attempts: float = inf,
            timeout_seconds: float = inf,
            seconds_between: float = 0,
            callback: AttemptCallbackT = None) -> AsyncFunction:
    """
    Decorator allowing an async function to be called repeatedly, if previous attempts cause specific exceptions.
    Note: If no limiting arguments are passed to the decorator, the decorated function **will** be called repeatedly in
    a potentially infinite loop, as long as it keeps throwing an exception.

    Args:
        _func:
            Control parameter; allows using the decorator with or without arguments.
            If this decorator is used *without any* arguments, this will always be the decorated function itself.
        exception (optional):
            An `Exception` (sub-)class or a sequence thereof; a failed call of the decorated function will only be
            repeated if it fails with a matching exception. Defaults to `Exception`, i.e. any exception.
        max_attempts (optional):
            The maximum number of (re-)attempts at calling the decorated function; if it is called `max_attempts` times
            and fails, the exception will be propagated. The number of attempts is unlimited by default.
        timeout_seconds (optional):
            Defines the cutoff time (in seconds) for the entirety of attempts at executing the decorated function;
            if the attempts take longer in total and fail, the exception will be propagated. No timeout by default.
        seconds_between (optional):
            Sets a sleep interval (in seconds) between each attempt to call the decorated function. Defaults to 0.
        callback (optional):
            If passed an async function (with matching parameters), a failed **and caught** attempt will call it with
            the following positional arguments (in that order):
            - the decorated async function itself
            - the exception class encountered and caught
            - the total number of failed attempts up to that point
            - the `seconds_between` argument
            - positional and keyword arguments (as tuple and dictionary respectively) passed to the decorated function

    Raises:
        Any exceptions that do **not** match those passed to the `exception` parameter are immediately propagated.
        Those that were specified in `exception` are propagated when `max_attempts` or `timeout_seconds` are reached.
    """
    # Fix: an `except` clause only accepts an exception class or a *tuple* of classes;
    # any other sequence (e.g. a list, as the annotation allows) would raise `TypeError`
    # at catch time. Normalize once, up front.
    catchable = exception if isinstance(exception, type) else tuple(exception)

    def decorator(function: AsyncFunction) -> AsyncFunction:
        # Using `functools.wraps` to preserve information about the actual function being decorated
        # More details: https://docs.python.org/3/library/functools.html#functools.wraps
        @wraps(function)
        async def wrapper(*args, **kwargs) -> Any:
            start, failed_attempts = default_timer(), 0
            while True:
                try:
                    return await function(*args, **kwargs)
                except catchable as e:
                    failed_attempts += 1
                    # Give up once either limit is hit; bare `raise` preserves the
                    # original traceback without adding an extra frame.
                    if default_timer() - start >= timeout_seconds or failed_attempts >= max_attempts:
                        raise
                    if callback:
                        await callback(function, e, failed_attempts, seconds_between, args, kwargs)
                    await asyncio.sleep(seconds_between)
        return wrapper
    return decorator if _func is None else decorator(_func)
async def log_failed_attempt(f: Callable, e: Exception, n: int, delay: float, args: tuple, kwargs: dict) -> None:
    """
    Intended to be the prototypical `callback` argument for the `attempt` decorator above.

    Logs a warning describing the failed call, its arguments, the caught exception
    and the retry delay.
    """
    # Fix: join positional and keyword argument representations as one list, so no
    # stray comma appears when either group is empty (previously e.g. `f(, x=1)`).
    pieces = [repr(arg) for arg in args]
    pieces.extend(f'{k}={repr(v)}' for k, v in kwargs.items())
    arg_str = ', '.join(pieces)
    logger.warning(f"Attempt {n} at {f.__name__}({arg_str}) failed with {repr(e)}; retrying after {delay} seconds.")
async def gather_in_batches(batch_size: int, *aws: Awaitable, return_exceptions: bool = False) -> list:
    """
    Simple extension of the `asyncio.gather` function to make it easy to run awaitable objects in concurrent batches.
    (see: https://docs.python.org/3/library/asyncio-task.html#asyncio.gather)
    A batch is run concurrently using `gather`, while the calls to `gather` for each batch are done sequentially.
    This can be useful if for example there is a very large number of requests to a single website to be made
    and you want to make use of concurrency, but not by performing all of them at the same time.

    Args:
        batch_size:
            The maximum number of awaitable objects to run concurrently at any given moment.
            If this is higher than the number of awaitable objects, this is equivalent to a single `gather` call.
            Every batch will have this number of concurrent runs, but the last batch can obviously be smaller.
        aws:
            The awaitable objects/coroutines; subsets of these are passed to `gather` always maintaining overall order
        return_exceptions (optional):
            Passed into each `gather` call

    Returns:
        The aggregate list of returned values for all awaitable objects.
        The order of result values corresponds to the order of `aws`.

    Raises:
        ValueError: If `batch_size` is less than 1. (Previously a negative value made the
            batching `range` empty, silently returning `[]` without awaiting anything.)
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    results = []
    for idx in range(0, len(aws), batch_size):
        batch = aws[idx:idx + batch_size]
        results.extend(await asyncio.gather(*batch, return_exceptions=return_exceptions))
    return results