From 84fa2d2cd9a9e183dd3161e8e53698844384f1e5 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 12 Mar 2023 18:04:28 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20Expand=20generic=20utility=20mix?= =?UTF-8?q?in=20to=20handle=20up=20to=205=20type=20arguments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/marshmallow_generic/_util.py | 97 ++++++++++++++++++++++++++----- src/marshmallow_generic/schema.py | 6 +- tests/test__util.py | 51 ++++++++++++---- tests/test_schema.py | 2 +- 4 files changed, 128 insertions(+), 28 deletions(-) diff --git a/src/marshmallow_generic/_util.py b/src/marshmallow_generic/_util.py index 3426876..ac6686a 100644 --- a/src/marshmallow_generic/_util.py +++ b/src/marshmallow_generic/_util.py @@ -1,10 +1,28 @@ -from typing import Any, Generic, Optional, TypeVar, get_args, get_origin +from typing import ( + Any, + Generic, + Literal, + Optional, + TypeVar, + Union, + get_args, + get_origin, + overload, +) -_T = TypeVar("_T") +_T0 = TypeVar("_T0") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") +_T4 = TypeVar("_T4") -class GenericInsightMixin(Generic[_T]): - _type_arg: Optional[type[_T]] = None +class GenericInsightMixin(Generic[_T0, _T1, _T2, _T3, _T4]): + _type_arg_0: Optional[type[_T0]] = None + _type_arg_1: Optional[type[_T1]] = None + _type_arg_2: Optional[type[_T2]] = None + _type_arg_3: Optional[type[_T3]] = None + _type_arg_4: Optional[type[_T4]] = None @classmethod def __init_subclass__(cls, **kwargs: Any) -> None: @@ -14,17 +32,70 @@ class GenericInsightMixin(Generic[_T]): origin = get_origin(base) if origin is None or not issubclass(origin, GenericInsightMixin): continue - type_arg = get_args(base)[0] - # Do not set the attribute for GENERIC subclasses! - if not isinstance(type_arg, TypeVar): - cls._type_arg = type_arg - return + type_args = get_args(base) + for idx, arg in enumerate(type_args): + # Do not set the attribute for generics: + if isinstance(arg, TypeVar): + continue + # Do not set `NoneType`: + if isinstance(arg, type) and isinstance(None, arg): + continue + setattr(cls, f"_type_arg_{idx}", arg) + return @classmethod - def _get_type_arg(cls) -> type[_T]: + @overload + def _get_type_arg(cls, idx: Literal[0]) -> type[_T0]: + ... + + @classmethod + @overload + def _get_type_arg(cls, idx: Literal[1]) -> type[_T1]: + ... + + @classmethod + @overload + def _get_type_arg(cls, idx: Literal[2]) -> type[_T2]: + ... + + @classmethod + @overload + def _get_type_arg(cls, idx: Literal[3]) -> type[_T3]: + ... + + @classmethod + @overload + def _get_type_arg(cls, idx: Literal[4]) -> type[_T4]: + ... + + @classmethod + def _get_type_arg( + cls, + idx: Literal[0, 1, 2, 3, 4], + ) -> Union[type[_T0], type[_T1], type[_T2], type[_T3], type[_T4]]: """Returns the type argument of the class (if specified).""" - if cls._type_arg is None: + if idx == 0: + type_ = cls._type_arg_0 + elif idx == 1: + type_ = cls._type_arg_1 + elif idx == 2: # noqa: PLR2004 + type_ = cls._type_arg_2 + elif idx == 3: # noqa: PLR2004 + type_ = cls._type_arg_3 + elif idx == 4: # noqa: PLR2004 + type_ = cls._type_arg_4 + else: + raise ValueError("Only 5 type parameters available") + if type_ is None: raise AttributeError( - f"{cls.__name__} is generic; type argument unspecified" + f"{cls.__name__} is generic; type argument {idx} unspecified" ) - return cls._type_arg + return type_ + + +class GenericInsightMixin1(GenericInsightMixin[_T0, None, None, None, None]): + pass + + +class GenericInsightMixin2(GenericInsightMixin[_T0, _T1, None, None, None]): + pass diff --git a/src/marshmallow_generic/schema.py b/src/marshmallow_generic/schema.py index 8e68ecd..270da53 100644 --- a/src/marshmallow_generic/schema.py +++ b/src/marshmallow_generic/schema.py @@ -10,13 +10,13 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, overlo from marshmallow import Schema -from ._util import GenericInsightMixin +from ._util import GenericInsightMixin1 from .decorators import post_load Model = TypeVar("Model") -class GenericSchema(GenericInsightMixin[Model], Schema): +class GenericSchema(GenericInsightMixin1[Model], Schema): """ Generic schema parameterized by a **`Model`** class. @@ -65,7 +65,7 @@ class GenericSchema(GenericInsightMixin[Model], Schema): Returns: Instance of the schema's **`Model`** initialized with `**data` """ - return self._get_type_arg()(**data) + return self._get_type_arg(0)(**data) if TYPE_CHECKING: diff --git a/tests/test__util.py b/tests/test__util.py index 7d00205..56749eb 100644 --- a/tests/test__util.py +++ b/tests/test__util.py @@ -12,7 +12,11 @@ class GenericInsightMixinTestCase(TestCase): mock_super.return_value = MagicMock(__init_subclass__=mock_super_meth) # Should be `None` by default: - self.assertIsNone(_util.GenericInsightMixin._type_arg) # type: ignore[misc] + self.assertIsNone(_util.GenericInsightMixin._type_arg_0) # type: ignore[misc] + self.assertIsNone(_util.GenericInsightMixin._type_arg_1) # type: ignore[misc] + self.assertIsNone(_util.GenericInsightMixin._type_arg_2) # type: ignore[misc] + self.assertIsNone(_util.GenericInsightMixin._type_arg_3) # type: ignore[misc] + self.assertIsNone(_util.GenericInsightMixin._type_arg_4) # type: ignore[misc] # If the mixin type argument was not specified (still generic), # ensure that the attribute remains `None` on the subclass: @@ -24,30 +28,55 @@ class GenericInsightMixinTestCase(TestCase): class Bar(Generic[t]): pass - class TestSchema1(Bar[str], _util.GenericInsightMixin[t]): + class TestCls(Bar[str], _util.GenericInsightMixin[t, None, int, str, bool]): pass - self.assertIsNone(TestSchema1._type_arg) # type: ignore[misc] + self.assertIsNone(TestCls._type_arg_0) # type: ignore[misc] + self.assertIsNone(TestCls._type_arg_1) # type: ignore[misc] + self.assertIs(int, TestCls._type_arg_2) # type: ignore[misc] + self.assertIs(str, TestCls._type_arg_3) # type: ignore[misc] + self.assertIs(bool, TestCls._type_arg_4) # type: ignore[misc] mock_super.assert_called_once() mock_super_meth.assert_called_once_with() mock_super.reset_mock() mock_super_meth.reset_mock() - # If the mixin type argument was specified, - # ensure it was assigned to the attribute on the child class: + # If the mixin type arguments were omitted, + # ensure the attributes remained `None`: - class TestSchema2(Bar[str], _util.GenericInsightMixin[Foo]): + class UnspecifiedCls(_util.GenericInsightMixin): # type: ignore[type-arg] pass - self.assertIs(Foo, TestSchema2._type_arg) # type: ignore[misc] + self.assertIsNone(UnspecifiedCls._type_arg_0) # type: ignore[misc] + self.assertIsNone(UnspecifiedCls._type_arg_1) # type: ignore[misc] + self.assertIsNone(UnspecifiedCls._type_arg_2) # type: ignore[misc] + self.assertIsNone(UnspecifiedCls._type_arg_3) # type: ignore[misc] + self.assertIsNone(UnspecifiedCls._type_arg_4) # type: ignore[misc] mock_super.assert_called_once() mock_super_meth.assert_called_once_with() def test__get_type_arg(self) -> None: with self.assertRaises(AttributeError): - _util.GenericInsightMixin._get_type_arg() + _util.GenericInsightMixin._get_type_arg(0) - _type = object() - with patch.object(_util.GenericInsightMixin, "_type_arg", new=_type): - self.assertIs(_type, _util.GenericInsightMixin._get_type_arg()) + _type_0 = object() + _type_1 = object() + _type_2 = object() + _type_3 = object() + _type_4 = object() + with patch.multiple( + _util.GenericInsightMixin, + _type_arg_0=_type_0, + _type_arg_1=_type_1, + _type_arg_2=_type_2, + _type_arg_3=_type_3, + _type_arg_4=_type_4, + ): + self.assertIs(_type_0, _util.GenericInsightMixin._get_type_arg(0)) + self.assertIs(_type_1, _util.GenericInsightMixin._get_type_arg(1)) + self.assertIs(_type_2, _util.GenericInsightMixin._get_type_arg(2)) + self.assertIs(_type_3, _util.GenericInsightMixin._get_type_arg(3)) + self.assertIs(_type_4, _util.GenericInsightMixin._get_type_arg(4)) + with self.assertRaises(ValueError): + _util.GenericInsightMixin._get_type_arg(5) # type: ignore[call-overload] diff --git a/tests/test_schema.py b/tests/test_schema.py index 59e17e6..c5e22ae 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -17,7 +17,7 @@ class GenericSchemaTestCase(TestCase): # Explicit annotation to possibly catch mypy errors: output: Foo = schema_obj.instantiate(mock_data) self.assertIs(mock_cls.return_value, output) - mock__get_type_arg.assert_called_once_with() + mock__get_type_arg.assert_called_once_with(0) mock_cls.assert_called_once_with(**mock_data) def test_load_and_loads(self) -> None: