From 79c2f16ade3645ba3b565cfa66299c29d619f736 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sat, 6 Aug 2022 16:31:38 +0200 Subject: [PATCH] Add `__validators__` and `__base__` params; improve types --- setup.cfg | 2 +- src/orm2pydantic/sqla.py | 47 ++++++++++++++++++++++++++-------------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/setup.cfg b/setup.cfg index ac6121f..3442962 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = orm2pydantic -version = 0.1.0 +version = 0.1.1 author = Daniil Fajnberg author_email = mail@daniil.fajnberg.de description = Convert SQLAlchemy models to Pydantic models diff --git a/src/orm2pydantic/sqla.py b/src/orm2pydantic/sqla.py index 7053087..150a9ba 100644 --- a/src/orm2pydantic/sqla.py +++ b/src/orm2pydantic/sqla.py @@ -18,7 +18,7 @@ __doc__ = """ Functions for turning SQLAlchemy objects into corresponding Pydantic objects. """ -from typing import Container, Type +from typing import Any, Callable, Container, Optional, Type, TypeVar from pydantic import create_model, BaseConfig, BaseModel, Field from pydantic.fields import FieldInfo @@ -39,7 +39,8 @@ __all__ = [ FieldDef = tuple[type, FieldInfo] - +ModelT = TypeVar('ModelT', bound=BaseModel) +ValidatorT = Callable[[BaseModel, ...], Any] _local_namespace: dict[str, BaseModel] = {} @@ -104,9 +105,16 @@ def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef: return list[rel_prop.mapper.class_.__name__], Field(default=None) -def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), - incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None, - add_local_ns: dict[str, BaseModel] = None): +def sqla2pydantic( + orm_model: Type[DeclarativeMeta], + exclude: Container[str] = (), + incl_relationships: bool = True, + add_fields: Optional[dict[str, FieldDef]] = None, + add_local_ns: Optional[dict[str, BaseModel]] = None, + __config__: Type[BaseConfig] = OrmConfig, + __base__: Optional[Type[ModelT]] = None, + __validators__: Optional[dict[str, ValidatorT]] = None +) -> Type[ModelT]: """ Takes an SQLAlchemy ORM model class and returns a matching Pydantic model class. @@ -118,11 +126,8 @@ def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = Or namespace of already created Pydantic models. Args: - db_model: + orm_model: The SQLAlchemy model; must be an instance of `DeclarativeMeta` - config (optional): - The inner model config class passed via the `__config__` parameter to `pydantic.create_model`; - by default the only explicit setting is `orm_mode = True`. exclude (optional): A container of strings, each of which represents the name of a field not to create in the Pydantic model; by default all fields of the original database model will be converted to Pydantic model fields. @@ -138,13 +143,22 @@ def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = Or May be passed a dictionary mapping additional Pydantic model names to the corresponding classes; these will be passed to the `BaseModel.update_forward_refs` method in addition to those being tracked internally anyway. + __config__ (optional): + The inner model config class passed via the `__config__` parameter to `pydantic.create_model`; + by default the only explicit setting is `orm_mode = True`. + __base__ (optional): + The base class for the new model to inherit from; + passed via the `__base__` parameter to `pydantic.create_model`. + __validators__ (optional): + Dictionary mapping method names to validation class methods that are decorated with `@pydantic.validator`; + passed via the `__validators__` parameter to `pydantic.create_model`. Returns: - Pydantic model class + Pydantic model class with fields corresponding to the specified ORM `db_model` """ - assert isinstance(db_model, DeclarativeMeta) + assert isinstance(orm_model, DeclarativeMeta) fields = {} - for attr in inspect(db_model).attrs: + for attr in inspect(orm_model).attrs: if attr.key in exclude: continue if isinstance(attr, ColumnProperty): @@ -156,9 +170,10 @@ def sqla2pydantic(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = Or raise AssertionError("Unknown attr type", attr) if add_fields is not None: fields |= add_fields - name = db_model.__name__ - pydantic_model = create_model(name, __config__=config, **fields) - pydantic_model.__name__ = name + model_name = orm_model.__name__ + pydantic_model = create_model(model_name, __config__=__config__, __base__=__base__, + __validators__=__validators__, **fields) + pydantic_model.__name__ = model_name pydantic_model.update_forward_refs(**_local_namespace if add_local_ns is None else _local_namespace | add_local_ns) - _local_namespace[name] = pydantic_model + _local_namespace[model_name] = pydantic_model return pydantic_model