Add docstrings; add option for additional namespace

This commit is contained in:
Daniil Fajnberg 2022-08-06 10:48:14 +02:00
parent 19224ea34a
commit 9a36634fb6

View File

@ -1,6 +1,6 @@
from typing import Container, Type
from pydantic import create_model, BaseConfig, Field
from pydantic import create_model, BaseConfig, BaseModel, Field
from pydantic.fields import FieldInfo
from sqlalchemy.inspection import inspect
@ -21,7 +21,7 @@ __all__ = [
FieldDef = tuple[type, FieldInfo]
_local_namespace = {}
_local_namespace: dict[str, BaseModel] = {}
class OrmConfig(BaseConfig):
@ -29,6 +29,16 @@ class OrmConfig(BaseConfig):
def field_from_column(col_prop: ColumnProperty) -> FieldDef:
"""
Takes a regular field of an SQLAlchemy ORM model and returns a corresponding Pydantic field definition.
Args:
col_prop: Instance of `sqlalchemy.orm.ColumnProperty` (i.e. not a relationship field)
Returns:
2-tuple with the first element being the Python type of the field and the second being a
`pydantic.fields.FieldInfo` instance with the correct `default` or `default_factory` parameter.
"""
assert len(col_prop.columns) == 1
column: Column = col_prop.columns[0]
try:
@ -54,6 +64,19 @@ def field_from_column(col_prop: ColumnProperty) -> FieldDef:
def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef:
"""
Takes a relationship field of an SQLAlchemy ORM model and returns a corresponding Pydantic field definition.
A Many-to-One relationship results in the type of the field being simply the name of the related model class,
whereas a One-to-Many relationship results in the type being a list parametrized with the name of that class.
Args:
rel_prop: Instance of `sqlalchemy.orm.RelationshipProperty` (i.e. not a regular field)
Returns:
2-tuple with the first element being the type of the field and the second being a
`pydantic.fields.FieldInfo` instance with the `default` parameter set to `None`.
"""
assert isinstance(rel_prop.mapper, Mapper)
if rel_prop.direction.name == 'MANYTOONE':
return rel_prop.mapper.class_.__name__, Field(default=None)
@ -62,7 +85,43 @@ def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef:
def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None):
incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None,
add_local_ns: dict[str, BaseModel] = None):
"""
Takes an SQLAlchemy ORM model class and returns a matching Pydantic model class.
Makes use of the `pydantic.create_model` function.
Handles default values set on the database model properly, including factory functions.
Can handle **acyclic** relationships between models by dynamically updating forward references using a local
namespace of already created Pydantic models.
Args:
db_model:
The SQLAlchemy model; must be an instance of `DeclarativeMeta`
config (optional):
The inner model config class passed via the `__config__` parameter to `pydantic.create_model`;
by default the only explicit setting is `orm_mode = True`.
exclude (optional):
A container of strings, each of which represents the name of a field not to create in the Pydantic model;
by default all fields of the original database model will be converted to Pydantic model fields.
incl_relationships (optional):
If set to `False`, fields representing relationships of the database model will not be converted.
Note that including all relationships may result in circular relationships that Pydantic cannot handle.
It may be advisable to selectively exclude certain relationship fields to avoid such issues.
Set to `True` by default.
add_fields (optional):
May be passed a dictionary mapping additional field names (not present in the database model) to appropriate
Pydantic field definitions; those fields will then also be present on the resulting Pydantic model.
add_local_ns (optional):
May be passed a dictionary mapping additional Pydantic model names to the corresponding classes;
these will be passed to the `BaseModel.update_forward_refs` method in addition to those being tracked
internally anyway.
Returns:
Pydantic model class
"""
assert isinstance(db_model, DeclarativeMeta)
fields = {}
for attr in inspect(db_model).attrs:
@ -80,6 +139,6 @@ def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmCon
name = db_model.__name__
pydantic_model = create_model(name, __config__=config, **fields)
pydantic_model.__name__ = name
pydantic_model.update_forward_refs(**_local_namespace)
pydantic_model.update_forward_refs(**_local_namespace if add_local_ns is None else _local_namespace | add_local_ns)
_local_namespace[name] = pydantic_model
return pydantic_model