generalized the decorator further, small refactoring; encountered a strange IDE issue

This commit is contained in:
Daniil Fajnberg 2021-11-28 19:43:47 +01:00
parent 35c516fa0d
commit aee9c96fa5
1 changed files with 43 additions and 13 deletions

View File

@ -1,6 +1,7 @@
import logging
from functools import wraps
from typing import Callable, Dict, Any
from inspect import signature
from typing import Callable, Dict, Tuple, Any
from aiohttp.client import ClientSession
@ -9,16 +10,28 @@ LOGGER_NAME = 'webutils'
logger = logging.getLogger(LOGGER_NAME)
def _get_param_idx_and_default(function: Callable, param_name: str) -> Tuple[int, Any]:
params = signature(function).parameters
return list(params.keys()).index(param_name), params[param_name].default
# TODO: Figure out why PyCharm does not produce type hints to a function decorated with this,
# when used without parentheses (like @in_async_session instead of @in_async_session()),
# as soon as the `Callable` return type is added to the signature.
def in_async_session(_func: Callable = None, *,
session_kwargs: Dict[str, Any] = None, session_param_name: str = 'session') -> Callable:
session_kwargs: Dict[str, Any] = None, session_param_name: str = 'session'):
"""
Useful decorator for any async function that uses the `aiohttp.ClientSession` to make requests.
Using this decorator allows the decorated function to have an optional session parameter,
without the need to ensure proper initialization and closing of a session within the function itself.
without the need to ensure proper initialization and closing of a session within the function body itself.
The wrapper has no effect, if a session object is passed into the function call, but if no session is passed,
it initializes one, passes it into the function and ensures that it is closed in the end.
The wrapper has no effect, if the default value is passed to the session parameter during the function call,
otherwise it initializes a session, passes it into the function and ensures that it is closed in the end.
The wrapper distinguishes between positional and keyword arguments. This means that if the function is called by
passing the default value as a positional argument, the temporary session is injected in its place as a positional
argument also; if the default is passed as a keyword argument, the temporary session is injected as such instead.
Args:
_func:
@ -34,22 +47,39 @@ def in_async_session(_func: Callable = None, *,
In case the decorated function's session parameter is named anything other than "session", that name should
be provided here.
"""
if session_kwargs is None:
session_kwargs = {}
def decorator(function: Callable) -> Callable:
# Using `functools.wraps` to preserve information about the actual function being decorated
# More details: https://docs.python.org/3/library/functools.html#functools.wraps
@wraps(function)
async def wrapper(*args, **kwargs):
async def wrapper(*args, **kwargs) -> Any:
"""The actual function wrapper that may perform the session initialization and closing."""
temp_session = False
if not any(isinstance(arg, ClientSession) for arg in args) and kwargs.get(session_param_name) is None:
logger.debug("Starting temporary client session")
kwargs[session_param_name] = ClientSession(**session_kwargs if session_kwargs is not None else {})
temp_session = True
temp_session = None
session_param_idx, session_param_default = _get_param_idx_and_default(function, session_param_name)
try:
session = args[session_param_idx]
except IndexError:
# If we end up here, this means there was no positional session argument passed.
session = kwargs.get(session_param_name)
if session == session_param_default:
logger.debug(f"Injecting temporary `ClientSession` as keyword argument `{session_param_name}` "
f"into `{function.__name__}`")
temp_session = ClientSession(**session_kwargs)
kwargs[session_param_name] = temp_session
else:
if session == session_param_default:
logger.debug(f"Injecting temporary `ClientSession` as positional argument {session_param_idx} "
f"into `{function.__name__}`")
temp_session = ClientSession(**session_kwargs)
args = list(args)
args[session_param_idx] = temp_session
try:
return await function(*args, **kwargs)
finally:
if temp_session:
await kwargs[session_param_name].close()
logger.debug("Temporary client session closed")
await temp_session.close()
logger.debug("Temporary `ClientSession` closed")
return wrapper
return decorator if _func is None else decorator(_func)