From 712f7fca7bce172bbc5dca0c045e2c630ef99c8e Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Mon, 13 Mar 2023 00:08:57 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Overload=20`dump`/`dumps`=20methods?= =?UTF-8?q?=20to=20increase=20type=20safety;=20issue=20a=20warning,=20when?= =?UTF-8?q?=20using=20the=20`many`=20parameter=20in=20`=5F=5Finit=5F=5F`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/marshmallow_generic/schema.py | 160 +++++++++++++++++++++++++++++- tests/test_schema.py | 44 ++++++++ 2 files changed, 200 insertions(+), 4 deletions(-) diff --git a/src/marshmallow_generic/schema.py b/src/marshmallow_generic/schema.py index 270da53..f077092 100644 --- a/src/marshmallow_generic/schema.py +++ b/src/marshmallow_generic/schema.py @@ -7,6 +7,7 @@ documentation of [`marshmallow.Schema`][marshmallow.Schema]. from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, overload +from warnings import warn from marshmallow import Schema @@ -43,6 +44,79 @@ class GenericSchema(GenericInsightMixin1[Model], Schema): ``` """ + def __init__( + self, + *, + only: Union[Sequence[str], set[str], None] = None, + exclude: Union[Sequence[str], set[str]] = (), + context: Union[dict[str, Any], None] = None, + load_only: Union[Sequence[str], set[str]] = (), + dump_only: Union[Sequence[str], set[str]] = (), + partial: Union[bool, Sequence[str], set[str]] = False, + unknown: Optional[str] = None, + many: Optional[bool] = None, + ) -> None: + """ + Emits a warning, if the `many` argument is not `None`. + + Otherwise the same as in [`marshmallow.Schema`][marshmallow.Schema]. + + Args: + only: + Whitelist of the declared fields to select when instantiating + the Schema. If `None`, all fields are used. Nested fields can + be represented with dot delimiters. + exclude: + Blacklist of the declared fields to exclude when instantiating + the Schema. If a field appears in both `only` and `exclude`, + it is not used. Nested fields can be represented with dot + delimiters. + context: + Optional context passed to [`Method`] + [marshmallow.fields.Method] and [`Function`] + [marshmallow.fields.Function] fields. + load_only: + Fields to skip during serialization (write-only fields) + dump_only: + Fields to skip during deserialization (read-only fields) + partial: + Whether to ignore missing fields and not require any fields + declared. Propagates down to [`Nested`] + [marshmallow.fields.Nested] fields as well. If its value is an + iterable, only missing fields listed in that iterable will be + ignored. Use dot delimiters to specify nested fields. + unknown: + Whether to exclude, include, or raise an error for unknown + fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. + many: + !!! warning + Specifying this option schema-wide undermines the type + safety that this class aims to provide and passing any + value other than `None` will trigger a warning. Use the + method-specific `many` parameter, when calling + [`dump`][marshmallow_generic.GenericSchema.dump]/ + [`dumps`][marshmallow_generic.GenericSchema.dumps] or + [`load`][marshmallow_generic.GenericSchema.load]/ + [`loads`][marshmallow_generic.GenericSchema.loads] instead. + """ + if many is not None: + warn( + "Setting `many` schema-wide breaks type safety. Use the the " + "`many` parameter of specific methods (like `load`) instead." + ) + else: + many = bool(many) + super().__init__( + only=only, + exclude=exclude, + many=many, + context=context, + load_only=load_only, + dump_only=dump_only, + partial=partial, + unknown=unknown, + ) + @post_load def instantiate(self, data: dict[str, Any], **_kwargs: Any) -> Model: """ @@ -52,10 +126,9 @@ class GenericSchema(GenericInsightMixin1[Model], Schema): [marshmallow_generic.decorators.post_load] hook for the schema. !!! warning - You should probably **not** use this method directly; - no parsing, transformation or validation of any kind is done - in this method. The `data` passed to the **`Model`** constructor - "as is". + You should probably not use this method directly. No parsing, + transformation or validation of any kind is done in this method. + The `data` is passed to the **`Model`** constructor "as is". Args: data: @@ -69,6 +142,85 @@ class GenericSchema(GenericInsightMixin1[Model], Schema): if TYPE_CHECKING: + @overload # type: ignore[override] + def dump( + self, + obj: Iterable[Model], + *, + many: Literal[True], + ) -> list[dict[str, Any]]: + ... + + @overload + def dump( + self, + obj: Model, + *, + many: Optional[Literal[False]] = None, + ) -> dict[str, Any]: + ... + + def dump( + self, + obj: Union[Model, Iterable[Model]], + *, + many: Optional[bool] = None, + ) -> Union[dict[str, Any], list[dict[str, Any]]]: + """ + Serializes **`Model`** objects to native Python data types. + + Same as [`marshmallow.Schema.dump`] + [marshmallow.schema.Schema.dump] at runtime. + + Annotations ensure that type checkers will infer the return type + correctly based on the `many` argument, and also enforce the `obj` + argument to be an a `list` of **`Model`** instances, if `many` is + set to `True` or a single instance of it, if `many` is `False` + (or omitted). + + Args: + obj: + The object or iterable of objects to serialize + many: + Whether to serialize `obj` as a collection. If `None`, the + value for `self.many` is used. + + Returns: + (dict[str, Any]): if `many` is set to `False` + (list[dict[str, Any]]): if `many` is set to `True` + """ + ... + + @overload # type: ignore[override] + def dumps( + self, + obj: Iterable[Model], + *args: Any, + many: Literal[True], + **kwargs: Any, + ) -> str: + ... + + @overload + def dumps( + self, + obj: Model, + *args: Any, + many: Optional[Literal[False]] = None, + **kwargs: Any, + ) -> str: + ... + + def dumps( + self, + obj: Union[Model, Iterable[Model]], + *args: Any, + many: Optional[bool] = None, + **kwargs: Any, + ) -> str: + """Same as [`dump`][marshmallow_generic.GenericSchema.dump], but returns a JSON-encoded string.""" + ... + @overload # type: ignore[override] def load( self, diff --git a/tests/test_schema.py b/tests/test_schema.py index c5e22ae..aed910d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,3 +1,4 @@ +from typing import Any from unittest import TestCase from unittest.mock import MagicMock, patch @@ -5,6 +6,29 @@ from marshmallow_generic import _util, schema class GenericSchemaTestCase(TestCase): + @patch("marshmallow.schema.Schema.__init__") + def test___init__(self, mock_super_init: MagicMock) -> None: + class Foo: + pass + + kwargs: dict[str, Any] = { + "only": object(), + "exclude": object(), + "context": object(), + "load_only": object(), + "dump_only": object(), + "partial": object(), + "unknown": object(), + "many": None, + } + schema.GenericSchema[Foo](**kwargs) + mock_super_init.assert_called_once_with(**kwargs | {"many": False}) + mock_super_init.reset_mock() + kwargs["many"] = True + with self.assertWarns(UserWarning): + schema.GenericSchema[Foo](**kwargs) + mock_super_init.assert_called_once_with(**kwargs) + @patch.object(_util.GenericInsightMixin, "_get_type_arg") def test_instantiate(self, mock__get_type_arg: MagicMock) -> None: mock__get_type_arg.return_value = mock_cls = MagicMock() @@ -20,6 +44,26 @@ class GenericSchemaTestCase(TestCase): mock__get_type_arg.assert_called_once_with(0) mock_cls.assert_called_once_with(**mock_data) + def test_dump_and_dumps(self) -> None: + """Mainly for static type checking purposes.""" + + class Foo: + pass + + class TestSchema(schema.GenericSchema[Foo]): + pass + + foo = Foo() + single: dict[str, Any] = TestSchema().dump(foo) + self.assertDictEqual({}, single) + json_string: str = TestSchema().dumps(foo) + self.assertEqual("{}", json_string) + + multiple: list[dict[str, Any]] = TestSchema().dump([foo], many=True) + self.assertListEqual([{}], multiple) + json_string = TestSchema().dumps([foo], many=True) + self.assertEqual("[{}]", json_string) + def test_load_and_loads(self) -> None: """Mainly for static type checking purposes."""