Overload @post_load decorator to retain the function type

This commit is contained in:
Daniil Fajnberg 2023-03-10 13:04:00 +01:00
parent 32ffea3b4e
commit f75679f6b2
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
3 changed files with 72 additions and 1 deletions

View File

@ -0,0 +1,41 @@
from collections.abc import Callable
from typing import Any, Optional, TypeVar, overload
from typing_extensions import ParamSpec
from marshmallow.decorators import post_load as _post_load
_R = TypeVar("_R")
_P = ParamSpec("_P")
@overload
def post_load(
fn: Callable[_P, _R],
pass_many: bool = False,
pass_original: bool = False,
) -> Callable[_P, _R]:
...
@overload
def post_load(
fn: None = None,
pass_many: bool = False,
pass_original: bool = False,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
...
def post_load(
fn: Optional[Callable[..., Any]] = None,
pass_many: bool = False,
pass_original: bool = False,
) -> Callable[..., Any]:
"""
Typed overload of the original `marshmallow.post_load` decorator function.
Generic to ensure that the decorated function retains its type.
Runtime behavior is unchanged.
"""
return _post_load(fn, pass_many=pass_many, pass_original=pass_original)

View File

@ -12,7 +12,9 @@ class GenericInsightMixinTestCase(TestCase):
mock_super.return_value = MagicMock(__init_subclass__=mock_super_meth)
# Should be `None` by default:
self.assertIsNone(_util.GenericInsightMixin._type_arg) # type: ignore[misc]
self.assertIsNone(
_util.GenericInsightMixin._type_arg # type: ignore[misc]
)
# If the mixin type argument was not specified (still generic),
# ensure that the attribute remains `None` on the subclass:

28
tests/test_decorators.py Normal file
View File

@ -0,0 +1,28 @@
from collections.abc import Callable
from unittest import TestCase
from unittest.mock import MagicMock, patch
from marshmallow_generic import decorators
class DecoratorsTestCase(TestCase):
@patch.object(decorators, "_post_load")
def test_post_load(self, mock_original_post_load: MagicMock) -> None:
mock_original_post_load.return_value = expected_output = object()
def test_function(x: int) -> str:
return str(x)
pass_many, pass_original = MagicMock(), MagicMock()
# Explicit annotation to possibly catch mypy errors:
output: Callable[[int], str] = decorators.post_load(
test_function,
pass_many=pass_many,
pass_original=pass_original,
)
self.assertIs(expected_output, output)
mock_original_post_load.assert_called_once_with(
test_function,
pass_many=pass_many,
pass_original=pass_original,
)