Overload dump/dumps methods to increase type safety;

issue a warning, when using the `many` parameter in `__init__`
This commit is contained in:
Daniil Fajnberg 2023-03-13 00:08:57 +01:00
parent 84fa2d2cd9
commit 712f7fca7b
Signed by: daniil-berg
GPG Key ID: BE187C50903BEE97
2 changed files with 200 additions and 4 deletions

View File

@ -7,6 +7,7 @@ documentation of [`marshmallow.Schema`][marshmallow.Schema].
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, overload
from warnings import warn
from marshmallow import Schema
@ -43,6 +44,79 @@ class GenericSchema(GenericInsightMixin1[Model], Schema):
```
"""
def __init__(
self,
*,
only: Union[Sequence[str], set[str], None] = None,
exclude: Union[Sequence[str], set[str]] = (),
context: Union[dict[str, Any], None] = None,
load_only: Union[Sequence[str], set[str]] = (),
dump_only: Union[Sequence[str], set[str]] = (),
partial: Union[bool, Sequence[str], set[str]] = False,
unknown: Optional[str] = None,
many: Optional[bool] = None,
) -> None:
"""
Emits a warning, if the `many` argument is not `None`.
Otherwise the same as in [`marshmallow.Schema`][marshmallow.Schema].
Args:
only:
Whitelist of the declared fields to select when instantiating
the Schema. If `None`, all fields are used. Nested fields can
be represented with dot delimiters.
exclude:
Blacklist of the declared fields to exclude when instantiating
the Schema. If a field appears in both `only` and `exclude`,
it is not used. Nested fields can be represented with dot
delimiters.
context:
Optional context passed to [`Method`]
[marshmallow.fields.Method] and [`Function`]
[marshmallow.fields.Function] fields.
load_only:
Fields to skip during serialization (write-only fields)
dump_only:
Fields to skip during deserialization (read-only fields)
partial:
Whether to ignore missing fields and not require any fields
declared. Propagates down to [`Nested`]
[marshmallow.fields.Nested] fields as well. If its value is an
iterable, only missing fields listed in that iterable will be
ignored. Use dot delimiters to specify nested fields.
unknown:
Whether to exclude, include, or raise an error for unknown
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
many:
!!! warning
Specifying this option schema-wide undermines the type
safety that this class aims to provide and passing any
value other than `None` will trigger a warning. Use the
method-specific `many` parameter, when calling
[`dump`][marshmallow_generic.GenericSchema.dump]/
[`dumps`][marshmallow_generic.GenericSchema.dumps] or
[`load`][marshmallow_generic.GenericSchema.load]/
[`loads`][marshmallow_generic.GenericSchema.loads] instead.
"""
if many is not None:
warn(
"Setting `many` schema-wide breaks type safety. Use the the "
"`many` parameter of specific methods (like `load`) instead."
)
else:
many = bool(many)
super().__init__(
only=only,
exclude=exclude,
many=many,
context=context,
load_only=load_only,
dump_only=dump_only,
partial=partial,
unknown=unknown,
)
@post_load
def instantiate(self, data: dict[str, Any], **_kwargs: Any) -> Model:
"""
@ -52,10 +126,9 @@ class GenericSchema(GenericInsightMixin1[Model], Schema):
[marshmallow_generic.decorators.post_load] hook for the schema.
!!! warning
You should probably **not** use this method directly;
no parsing, transformation or validation of any kind is done
in this method. The `data` passed to the **`Model`** constructor
"as is".
You should probably not use this method directly. No parsing,
transformation or validation of any kind is done in this method.
The `data` is passed to the **`Model`** constructor "as is".
Args:
data:
@ -69,6 +142,85 @@ class GenericSchema(GenericInsightMixin1[Model], Schema):
if TYPE_CHECKING:
@overload # type: ignore[override]
def dump(
self,
obj: Iterable[Model],
*,
many: Literal[True],
) -> list[dict[str, Any]]:
...
@overload
def dump(
self,
obj: Model,
*,
many: Optional[Literal[False]] = None,
) -> dict[str, Any]:
...
def dump(
self,
obj: Union[Model, Iterable[Model]],
*,
many: Optional[bool] = None,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
"""
Serializes **`Model`** objects to native Python data types.
Same as [`marshmallow.Schema.dump`]
[marshmallow.schema.Schema.dump] at runtime.
Annotations ensure that type checkers will infer the return type
correctly based on the `many` argument, and also enforce the `obj`
argument to be an a `list` of **`Model`** instances, if `many` is
set to `True` or a single instance of it, if `many` is `False`
(or omitted).
Args:
obj:
The object or iterable of objects to serialize
many:
Whether to serialize `obj` as a collection. If `None`, the
value for `self.many` is used.
Returns:
(dict[str, Any]): if `many` is set to `False`
(list[dict[str, Any]]): if `many` is set to `True`
"""
...
@overload # type: ignore[override]
def dumps(
self,
obj: Iterable[Model],
*args: Any,
many: Literal[True],
**kwargs: Any,
) -> str:
...
@overload
def dumps(
self,
obj: Model,
*args: Any,
many: Optional[Literal[False]] = None,
**kwargs: Any,
) -> str:
...
def dumps(
self,
obj: Union[Model, Iterable[Model]],
*args: Any,
many: Optional[bool] = None,
**kwargs: Any,
) -> str:
"""Same as [`dump`][marshmallow_generic.GenericSchema.dump], but returns a JSON-encoded string."""
...
@overload # type: ignore[override]
def load(
self,

View File

@ -1,3 +1,4 @@
from typing import Any
from unittest import TestCase
from unittest.mock import MagicMock, patch
@ -5,6 +6,29 @@ from marshmallow_generic import _util, schema
class GenericSchemaTestCase(TestCase):
@patch("marshmallow.schema.Schema.__init__")
def test___init__(self, mock_super_init: MagicMock) -> None:
class Foo:
pass
kwargs: dict[str, Any] = {
"only": object(),
"exclude": object(),
"context": object(),
"load_only": object(),
"dump_only": object(),
"partial": object(),
"unknown": object(),
"many": None,
}
schema.GenericSchema[Foo](**kwargs)
mock_super_init.assert_called_once_with(**kwargs | {"many": False})
mock_super_init.reset_mock()
kwargs["many"] = True
with self.assertWarns(UserWarning):
schema.GenericSchema[Foo](**kwargs)
mock_super_init.assert_called_once_with(**kwargs)
@patch.object(_util.GenericInsightMixin, "_get_type_arg")
def test_instantiate(self, mock__get_type_arg: MagicMock) -> None:
mock__get_type_arg.return_value = mock_cls = MagicMock()
@ -20,6 +44,26 @@ class GenericSchemaTestCase(TestCase):
mock__get_type_arg.assert_called_once_with(0)
mock_cls.assert_called_once_with(**mock_data)
def test_dump_and_dumps(self) -> None:
"""Mainly for static type checking purposes."""
class Foo:
pass
class TestSchema(schema.GenericSchema[Foo]):
pass
foo = Foo()
single: dict[str, Any] = TestSchema().dump(foo)
self.assertDictEqual({}, single)
json_string: str = TestSchema().dumps(foo)
self.assertEqual("{}", json_string)
multiple: list[dict[str, Any]] = TestSchema().dump([foo], many=True)
self.assertListEqual([{}], multiple)
json_string = TestSchema().dumps([foo], many=True)
self.assertEqual("[{}]", json_string)
def test_load_and_loads(self) -> None:
"""Mainly for static type checking purposes."""