From 9a36634fb6984bb833d3a636587c9cc1fedf1818 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sat, 6 Aug 2022 10:48:14 +0200 Subject: [PATCH] Add docstrings; add option for additional namespace --- src/orm2pydantic/sqla.py | 67 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/src/orm2pydantic/sqla.py b/src/orm2pydantic/sqla.py index 12e7feb..ff778f1 100644 --- a/src/orm2pydantic/sqla.py +++ b/src/orm2pydantic/sqla.py @@ -1,6 +1,6 @@ from typing import Container, Type -from pydantic import create_model, BaseConfig, Field +from pydantic import create_model, BaseConfig, BaseModel, Field from pydantic.fields import FieldInfo from sqlalchemy.inspection import inspect @@ -21,7 +21,7 @@ __all__ = [ FieldDef = tuple[type, FieldInfo] -_local_namespace = {} +_local_namespace: dict[str, BaseModel] = {} class OrmConfig(BaseConfig): @@ -29,6 +29,16 @@ class OrmConfig(BaseConfig): def field_from_column(col_prop: ColumnProperty) -> FieldDef: + """ + Takes a regular field of an SQLAlchemy ORM model and returns a corresponding Pydantic field definition. + + Args: + col_prop: Instance of `sqlalchemy.orm.ColumnProperty` (i.e. not a relationship field) + + Returns: + 2-tuple with the first element being the Python type of the field and the second being a + `pydantic.fields.FieldInfo` instance with the correct `default` or `default_factory` parameter. + """ assert len(col_prop.columns) == 1 column: Column = col_prop.columns[0] try: @@ -54,6 +64,19 @@ def field_from_column(col_prop: ColumnProperty) -> FieldDef: def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef: + """ + Takes a relationship field of an SQLAlchemy ORM model and returns a corresponding Pydantic field definition. + + A Many-to-One relationship results in the type of the field being simply the name of the related model class, + whereas a One-to-Many relationship results in the type being a list parametrized with the name of that class. + + Args: + rel_prop: Instance of `sqlalchemy.orm.RelationshipProperty` (i.e. not a regular field) + + Returns: + 2-tuple with the first element being the type of the field and the second being a + `pydantic.fields.FieldInfo` instance with the `default` parameter set to `None`. + """ assert isinstance(rel_prop.mapper, Mapper) if rel_prop.direction.name == 'MANYTOONE': return rel_prop.mapper.class_.__name__, Field(default=None) @@ -62,7 +85,43 @@ def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef: def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), - incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None): + incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None, + add_local_ns: dict[str, BaseModel] = None): + """ + Takes an SQLAlchemy ORM model class and returns a matching Pydantic model class. + + Makes use of the `pydantic.create_model` function. + + Handles default values set on the database model properly, including factory functions. + + Can handle **acyclic** relationships between models by dynamically updating forward references using a local + namespace of already created Pydantic models. + + Args: + db_model: + The SQLAlchemy model; must be an instance of `DeclarativeMeta` + config (optional): + The inner model config class passed via the `__config__` parameter to `pydantic.create_model`; + by default the only explicit setting is `orm_mode = True`. + exclude (optional): + A container of strings, each of which represents the name of a field not to create in the Pydantic model; + by default all fields of the original database model will be converted to Pydantic model fields. + incl_relationships (optional): + If set to `False`, fields representing relationships of the database model will not be converted. + Note that including all relationships may result in circular relationships that Pydantic cannot handle. + It may be advisable to selectively exclude certain relationship fields to avoid such issues. + Set to `True` by default. + add_fields (optional): + May be passed a dictionary mapping additional field names (not present in the database model) to appropriate + Pydantic field definitions; those fields will then also be present on the resulting Pydantic model. + add_local_ns (optional): + May be passed a dictionary mapping additional Pydantic model names to the corresponding classes; + these will be passed to the `BaseModel.update_forward_refs` method in addition to those being tracked + internally anyway. + + Returns: + Pydantic model class + """ assert isinstance(db_model, DeclarativeMeta) fields = {} for attr in inspect(db_model).attrs: @@ -80,6 +139,6 @@ def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmCon name = db_model.__name__ pydantic_model = create_model(name, __config__=config, **fields) pydantic_model.__name__ = name - pydantic_model.update_forward_refs(**_local_namespace) + pydantic_model.update_forward_refs(**_local_namespace if add_local_ns is None else _local_namespace | add_local_ns) _local_namespace[name] = pydantic_model return pydantic_model