diff --git a/src/marshmallow_generic/schema.py b/src/marshmallow_generic/schema.py new file mode 100644 index 0000000..c857bc0 --- /dev/null +++ b/src/marshmallow_generic/schema.py @@ -0,0 +1,68 @@ +"""Definition of the `GenericSchema` base class.""" + +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, overload + +from marshmallow import Schema + +from ._util import GenericInsightMixin +from .decorators import post_load + +_T = TypeVar("_T") + + +class GenericSchema(GenericInsightMixin[_T], Schema): + """ + Schema parameterized by the class it deserializes data to. + + Registers a `post_load` hook to pass validated data to the constructor + of the specified class. + + Requires a specific (non-generic) class to be passed as the type argument + for deserialization to work properly. + """ + + @post_load + def instantiate(self, data: dict[str, Any], **_kwargs: Any) -> _T: + """Unpacks `data` into the constructor of the specified type.""" + return self._get_type_arg()(**data) + + if TYPE_CHECKING: + + @overload # type: ignore[override] + def load( + self, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + *, + many: Literal[True], + partial: Union[bool, Sequence[str], set[str], None] = None, + unknown: Optional[str] = None, + ) -> list[_T]: + ... + + @overload + def load( + self, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + *, + many: Optional[Literal[False]] = None, + partial: Union[bool, Sequence[str], set[str], None] = None, + unknown: Optional[str] = None, + ) -> _T: + ... + + def load( + self, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + *, + many: Optional[bool] = None, + partial: Union[bool, Sequence[str], set[str], None] = None, + unknown: Optional[str] = None, + ) -> Union[list[_T], _T]: + """ + Same as `marshmallow.Schema.load` at runtime. + + Annotations ensure that type checkers will infer the return type + correctly based on the type argument passed to a specific subclass. + """ + ... diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..0c4f6e2 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,37 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from marshmallow_generic import _util, schema + + +class GenericSchemaTestCase(TestCase): + @patch.object(_util.GenericInsightMixin, "_get_type_arg") + def test_instantiate(self, mock__get_type_arg: MagicMock) -> None: + mock__get_type_arg.return_value = mock_cls = MagicMock() + mock_data = {"foo": "bar", "spam": 123} + + class Foo: + pass + + schema_obj = schema.GenericSchema[Foo]() + # Explicit annotation to possibly catch mypy errors: + output: Foo = schema_obj.instantiate(mock_data) + self.assertIs(mock_cls.return_value, output) + mock__get_type_arg.assert_called_once_with() + mock_cls.assert_called_once_with(**mock_data) + + def test_load(self) -> None: + """Mainly for static type checking purposes.""" + + class Foo: + pass + + class TestSchema(schema.GenericSchema[Foo]): + pass + + single: Foo = TestSchema().load({}) + self.assertIsInstance(single, Foo) + + multiple: list[Foo] = TestSchema().load([{}], many=True) + self.assertIsInstance(multiple, list) + self.assertIsInstance(multiple[0], Foo)