diff --git a/src/marshmallow_generic/_util.py b/src/marshmallow_generic/_util.py new file mode 100644 index 0000000..bf9f152 --- /dev/null +++ b/src/marshmallow_generic/_util.py @@ -0,0 +1,31 @@ +from typing import Any, Generic, Optional, TypeVar, get_args, get_origin + + +_T = TypeVar("_T") + + +class GenericInsightMixin(Generic[_T]): + _type_arg: Optional[type[_T]] = None + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + """Saves the type argument in the `_type_arg` class attribute.""" + super().__init_subclass__(**kwargs) + for base in cls.__orig_bases__: # type: ignore[attr-defined] + origin = get_origin(base) + if origin is None or not issubclass(origin, GenericInsightMixin): + continue + type_arg = get_args(base)[0] + # Do not set the attribute for GENERIC subclasses! + if not isinstance(type_arg, TypeVar): + cls._type_arg = type_arg + return + + @classmethod + def _get_type_arg(cls) -> type[_T]: + """Returns the type argument of the class (if specified).""" + if cls._type_arg is None: + raise AttributeError( + f"{cls.__name__} is generic; type argument unspecified" + ) + return cls._type_arg diff --git a/tests/test__util.py b/tests/test__util.py new file mode 100644 index 0000000..7d00205 --- /dev/null +++ b/tests/test__util.py @@ -0,0 +1,53 @@ +from typing import Generic, TypeVar +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from marshmallow_generic import _util + + +class GenericInsightMixinTestCase(TestCase): + @patch.object(_util, "super") + def test___init_subclass__(self, mock_super: MagicMock) -> None: + mock_super_meth = MagicMock() + mock_super.return_value = MagicMock(__init_subclass__=mock_super_meth) + + # Should be `None` by default: + self.assertIsNone(_util.GenericInsightMixin._type_arg) # type: ignore[misc] + + # If the mixin type argument was not specified (still generic), + # ensure that the attribute remains `None` on the subclass: + t = TypeVar("t") + + class Foo: + pass + + class Bar(Generic[t]): + pass + + class TestSchema1(Bar[str], _util.GenericInsightMixin[t]): + pass + + self.assertIsNone(TestSchema1._type_arg) # type: ignore[misc] + mock_super.assert_called_once() + mock_super_meth.assert_called_once_with() + + mock_super.reset_mock() + mock_super_meth.reset_mock() + + # If the mixin type argument was specified, + # ensure it was assigned to the attribute on the child class: + + class TestSchema2(Bar[str], _util.GenericInsightMixin[Foo]): + pass + + self.assertIs(Foo, TestSchema2._type_arg) # type: ignore[misc] + mock_super.assert_called_once() + mock_super_meth.assert_called_once_with() + + def test__get_type_arg(self) -> None: + with self.assertRaises(AttributeError): + _util.GenericInsightMixin._get_type_arg() + + _type = object() + with patch.object(_util.GenericInsightMixin, "_type_arg", new=_type): + self.assertIs(_type, _util.GenericInsightMixin._get_type_arg())