Add `__validators__` and `__base__` params; improve types

This commit is contained in:
Daniil Fajnberg 2022-08-06 16:31:38 +02:00
parent 23d55902b0
commit 79c2f16ade
2 changed files with 32 additions and 17 deletions

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = orm2pydantic name = orm2pydantic
version = 0.1.0 version = 0.1.1
author = Daniil Fajnberg author = Daniil Fajnberg
author_email = mail@daniil.fajnberg.de author_email = mail@daniil.fajnberg.de
description = Convert SQLAlchemy models to Pydantic models description = Convert SQLAlchemy models to Pydantic models

View File

@ -18,7 +18,7 @@ __doc__ = """
Functions for turning SQLAlchemy objects into corresponding Pydantic objects. Functions for turning SQLAlchemy objects into corresponding Pydantic objects.
""" """
from typing import Container, Type from typing import Any, Callable, Container, Optional, Type, TypeVar
from pydantic import create_model, BaseConfig, BaseModel, Field from pydantic import create_model, BaseConfig, BaseModel, Field
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -39,7 +39,8 @@ __all__ = [
FieldDef = tuple[type, FieldInfo] FieldDef = tuple[type, FieldInfo]
ModelT = TypeVar('ModelT', bound=BaseModel)
ValidatorT = Callable[[BaseModel, ...], Any]
_local_namespace: dict[str, BaseModel] = {} _local_namespace: dict[str, BaseModel] = {}
@ -104,9 +105,16 @@ def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef:
return list[rel_prop.mapper.class_.__name__], Field(default=None) return list[rel_prop.mapper.class_.__name__], Field(default=None)
def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), def sqla2pydantic(
incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None, orm_model: Type[DeclarativeMeta],
add_local_ns: dict[str, BaseModel] = None): exclude: Container[str] = (),
incl_relationships: bool = True,
add_fields: Optional[dict[str, FieldDef]] = None,
add_local_ns: Optional[dict[str, BaseModel]] = None,
__config__: Type[BaseConfig] = OrmConfig,
__base__: Optional[Type[ModelT]] = None,
__validators__: Optional[dict[str, ValidatorT]] = None
) -> Type[ModelT]:
""" """
Takes an SQLAlchemy ORM model class and returns a matching Pydantic model class. Takes an SQLAlchemy ORM model class and returns a matching Pydantic model class.
@ -118,11 +126,8 @@ def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = Or
namespace of already created Pydantic models. namespace of already created Pydantic models.
Args: Args:
db_model: orm_model:
The SQLAlchemy model; must be an instance of `DeclarativeMeta` The SQLAlchemy model; must be an instance of `DeclarativeMeta`
config (optional):
The inner model config class passed via the `__config__` parameter to `pydantic.create_model`;
by default the only explicit setting is `orm_mode = True`.
exclude (optional): exclude (optional):
A container of strings, each of which represents the name of a field not to create in the Pydantic model; A container of strings, each of which represents the name of a field not to create in the Pydantic model;
by default all fields of the original database model will be converted to Pydantic model fields. by default all fields of the original database model will be converted to Pydantic model fields.
@ -138,13 +143,22 @@ def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = Or
May be passed a dictionary mapping additional Pydantic model names to the corresponding classes; May be passed a dictionary mapping additional Pydantic model names to the corresponding classes;
these will be passed to the `BaseModel.update_forward_refs` method in addition to those being tracked these will be passed to the `BaseModel.update_forward_refs` method in addition to those being tracked
internally anyway. internally anyway.
__config__ (optional):
The inner model config class passed via the `__config__` parameter to `pydantic.create_model`;
by default the only explicit setting is `orm_mode = True`.
__base__ (optional):
The base class for the new model to inherit from;
passed via the `__base__` parameter to `pydantic.create_model`.
__validators__ (optional):
Dictionary mapping method names to validation class methods that are decorated with `@pydantic.validator`;
passed via the `__validators__` parameter to `pydantic.create_model`.
Returns: Returns:
Pydantic model class Pydantic model class with fields corresponding to the specified ORM `db_model`
""" """
assert isinstance(db_model, DeclarativeMeta) assert isinstance(orm_model, DeclarativeMeta)
fields = {} fields = {}
for attr in inspect(db_model).attrs: for attr in inspect(orm_model).attrs:
if attr.key in exclude: if attr.key in exclude:
continue continue
if isinstance(attr, ColumnProperty): if isinstance(attr, ColumnProperty):
@ -156,9 +170,10 @@ def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = Or
raise AssertionError("Unknown attr type", attr) raise AssertionError("Unknown attr type", attr)
if add_fields is not None: if add_fields is not None:
fields |= add_fields fields |= add_fields
name = db_model.__name__ model_name = orm_model.__name__
pydantic_model = create_model(name, __config__=config, **fields) pydantic_model = create_model(model_name, __config__=__config__, __base__=__base__,
pydantic_model.__name__ = name __validators__=__validators__, **fields)
pydantic_model.__name__ = model_name
pydantic_model.update_forward_refs(**_local_namespace if add_local_ns is None else _local_namespace | add_local_ns) pydantic_model.update_forward_refs(**_local_namespace if add_local_ns is None else _local_namespace | add_local_ns)
_local_namespace[name] = pydantic_model _local_namespace[model_name] = pydantic_model
return pydantic_model return pydantic_model