generalized the decorator further, small refactoring; encountered a strange IDE issue
This commit is contained in:
parent
35c516fa0d
commit
aee9c96fa5
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Any
|
||||
from inspect import signature
|
||||
from typing import Callable, Dict, Tuple, Any
|
||||
|
||||
from aiohttp.client import ClientSession
|
||||
|
||||
@ -9,16 +10,28 @@ LOGGER_NAME = 'webutils'
|
||||
logger = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
|
||||
def _get_param_idx_and_default(function: Callable, param_name: str) -> Tuple[int, Any]:
|
||||
params = signature(function).parameters
|
||||
return list(params.keys()).index(param_name), params[param_name].default
|
||||
|
||||
|
||||
# TODO: Figure out why PyCharm does not produce type hints to a function decorated with this,
|
||||
# when used without parentheses (like @in_async_session instead of @in_async_session()),
|
||||
# as soon as the `Callable` return type is added to the signature.
|
||||
def in_async_session(_func: Callable = None, *,
|
||||
session_kwargs: Dict[str, Any] = None, session_param_name: str = 'session') -> Callable:
|
||||
session_kwargs: Dict[str, Any] = None, session_param_name: str = 'session'):
|
||||
"""
|
||||
Useful decorator for any async function that uses the `aiohttp.ClientSession` to make requests.
|
||||
|
||||
Using this decorator allows the decorated function to have an optional session parameter,
|
||||
without the need to ensure proper initialization and closing of a session within the function itself.
|
||||
without the need to ensure proper initialization and closing of a session within the function body itself.
|
||||
|
||||
The wrapper has no effect, if a session object is passed into the function call, but if no session is passed,
|
||||
it initializes one, passes it into the function and ensures that it is closed in the end.
|
||||
The wrapper has no effect, if the default value is passed to the session parameter during the function call,
|
||||
otherwise it initializes a session, passes it into the function and ensures that it is closed in the end.
|
||||
|
||||
The wrapper distinguishes between positional and keyword arguments. This means that if the function is called by
|
||||
passing the default value as a positional argument, the temporary session is injected in its place as a positional
|
||||
argument also; if the default is passed as a keyword argument, the temporary session is injected as such instead.
|
||||
|
||||
Args:
|
||||
_func:
|
||||
@ -34,22 +47,39 @@ def in_async_session(_func: Callable = None, *,
|
||||
In case the decorated function's session parameter is named anything other than "session", that name should
|
||||
be provided here.
|
||||
"""
|
||||
if session_kwargs is None:
|
||||
session_kwargs = {}
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Using `functools.wraps` to preserve information about the actual function being decorated
|
||||
# More details: https://docs.python.org/3/library/functools.html#functools.wraps
|
||||
@wraps(function)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
"""The actual function wrapper that may perform the session initialization and closing."""
|
||||
temp_session = False
|
||||
if not any(isinstance(arg, ClientSession) for arg in args) and kwargs.get(session_param_name) is None:
|
||||
logger.debug("Starting temporary client session")
|
||||
kwargs[session_param_name] = ClientSession(**session_kwargs if session_kwargs is not None else {})
|
||||
temp_session = True
|
||||
temp_session = None
|
||||
session_param_idx, session_param_default = _get_param_idx_and_default(function, session_param_name)
|
||||
try:
|
||||
session = args[session_param_idx]
|
||||
except IndexError:
|
||||
# If we end up here, this means there was no positional session argument passed.
|
||||
session = kwargs.get(session_param_name)
|
||||
if session == session_param_default:
|
||||
logger.debug(f"Injecting temporary `ClientSession` as keyword argument `{session_param_name}` "
|
||||
f"into `{function.__name__}`")
|
||||
temp_session = ClientSession(**session_kwargs)
|
||||
kwargs[session_param_name] = temp_session
|
||||
else:
|
||||
if session == session_param_default:
|
||||
logger.debug(f"Injecting temporary `ClientSession` as positional argument {session_param_idx} "
|
||||
f"into `{function.__name__}`")
|
||||
temp_session = ClientSession(**session_kwargs)
|
||||
args = list(args)
|
||||
args[session_param_idx] = temp_session
|
||||
try:
|
||||
return await function(*args, **kwargs)
|
||||
finally:
|
||||
if temp_session:
|
||||
await kwargs[session_param_name].close()
|
||||
logger.debug("Temporary client session closed")
|
||||
await temp_session.close()
|
||||
logger.debug("Temporary `ClientSession` closed")
|
||||
return wrapper
|
||||
return decorator if _func is None else decorator(_func)
|
||||
|
Loading…
Reference in New Issue
Block a user