From a64708dedddb1053ec7fb2c57980e7eb7b6f424d Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Wed, 3 Aug 2022 11:17:05 +0200 Subject: [PATCH] Implement first simple working function and tests --- src/orm2pydantic/sqla.py | 82 ++++++++++++++++++++++++++++++++++ src/orm2pydantic/utils.py | 20 +++++++++ tests/informal.py | 93 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 195 insertions(+) create mode 100644 src/orm2pydantic/sqla.py create mode 100644 src/orm2pydantic/utils.py create mode 100644 tests/informal.py diff --git a/src/orm2pydantic/sqla.py b/src/orm2pydantic/sqla.py new file mode 100644 index 0000000..ec7a7a8 --- /dev/null +++ b/src/orm2pydantic/sqla.py @@ -0,0 +1,82 @@ +from typing import Container, Type + +from pydantic import create_model, BaseConfig, Field +from pydantic.fields import FieldInfo + +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import ColumnProperty, RelationshipProperty, Mapper +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.sql.schema import Column, ColumnDefault + +from .utils import resolve_dotted_path + + +__all__ = [ + 'field_from_column', + 'from_sqla' +] + + +FieldDef = tuple[type, FieldInfo] + + +_local_namespace = {} + + +class OrmConfig(BaseConfig): + orm_mode = True + + +def field_from_column(column: Column) -> FieldDef: + try: + field_type = column.type.impl.python_type + except AttributeError: + try: + field_type = column.type.python_type + except AttributeError: + raise AssertionError(f"Could not infer Python type for {column.key}") + default = ... if column.default is None and not column.nullable else column.default + if isinstance(default, ColumnDefault): + if default.is_scalar: + field_info = Field(default=default.arg) + else: + assert callable(default.arg) + dotted_path = default.arg.__module__ + '.' + default.arg.__name__ + factory = resolve_dotted_path(dotted_path) + assert callable(factory) + field_info = Field(default_factory=factory) + else: + field_info = Field(default=default) + return field_type, field_info + + +def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False, + config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), + add_fields: dict[str, FieldDef] = None): + assert isinstance(db_model, DeclarativeMeta) + assert not (incl_one_to_many and incl_many_to_one) + fields = {} + for attr in inspect(db_model).attrs: + if attr.key in exclude: + continue + if isinstance(attr, ColumnProperty): + assert len(attr.columns) == 1 + column = attr.columns[0] + fields[attr.key] = field_from_column(column) + elif isinstance(attr, RelationshipProperty): + related = attr.mapper + assert isinstance(related, Mapper) + if incl_many_to_one and attr.direction.name == 'MANYTOONE': + fields[attr.key] = (related.class_.__name__, Field(default=None)) + if incl_one_to_many and attr.direction.name == 'ONETOMANY': + fields[attr.key] = (list[related.class_.__name__], Field(default=None)) + else: + raise AssertionError("Unknown attr type", attr) + if add_fields is not None: + fields |= add_fields + name = db_model.__name__ + pydantic_model = create_model(name, __config__=config, **fields) + pydantic_model.__name__ = name + pydantic_model.update_forward_refs(**_local_namespace) + _local_namespace[name] = pydantic_model + return pydantic_model diff --git a/src/orm2pydantic/utils.py b/src/orm2pydantic/utils.py new file mode 100644 index 0000000..3ec1298 --- /dev/null +++ b/src/orm2pydantic/utils.py @@ -0,0 +1,20 @@ +from importlib import import_module + + +def resolve_dotted_path(dotted_path: str) -> object: + """ + Resolves a dotted path to a global object and returns that object. + + Algorithm shamelessly stolen from the `logging.config` module from the standard library. + """ + names = dotted_path.split('.') + module_name = names.pop(0) + found = import_module(module_name) + for name in names: + try: + found = getattr(found, name) + except AttributeError: + module_name += f'.{name}' + import_module(module_name) + found = getattr(found, name) + return found diff --git a/tests/informal.py b/tests/informal.py new file mode 100644 index 0000000..14d2beb --- /dev/null +++ b/tests/informal.py @@ -0,0 +1,93 @@ +from sqlalchemy.engine.create import create_engine +from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.orm.session import Session +from sqlalchemy.sql.functions import now as db_now +from sqlalchemy.sql.schema import Column, ForeignKey as FKey +from sqlalchemy.sql.sqltypes import Integer, String, TIMESTAMP, Unicode + +from orm2pydantic.sqla import from_sqla + + +ORMBase = declarative_base() +engine = create_engine("sqlite://") + + +def default_factory() -> str: return '1' + + +class AbstractBase(ORMBase): + __abstract__ = True + + NON_REPR_FIELDS = ['id', 'date_created', 'date_updated'] + + date_created = Column(TIMESTAMP(timezone=False), server_default=db_now()) + date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now()) + + +class StateProvince(AbstractBase): + __tablename__ = 'state_province' + + id = Column(Integer, primary_key=True) + country = Column(String(2), nullable=False, index=True) + name = Column(Unicode(255), nullable=False, index=True) + + cities = relationship('City', backref='state_province', lazy='selectin') + + +class City(AbstractBase): + __tablename__ = 'city' + + id = Column(Integer, primary_key=True) + state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True) + zip_code = Column(String(5), nullable=False, index=True) + name = Column(Unicode(255), nullable=False, index=True) + + streets = relationship('Street', backref='city', lazy='selectin') + + +class Street(AbstractBase): + __tablename__ = 'street' + + id = Column(Integer, primary_key=True) + city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True) + name = Column(Unicode(255), nullable=False, index=True) + + addresses = relationship('Address', backref='street', lazy='selectin') + + +class Address(AbstractBase): + __tablename__ = 'address' + + id = Column(Integer, primary_key=True) + street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True) + house_number = Column(String(8), nullable=False, default=default_factory) + supplement = Column(String(255)) + + +def main_test() -> None: + AbstractBase.metadata.create_all(engine) + + from_sqla(StateProvince) + from_sqla(City) + from_sqla(Street) + _PydanticAddress = from_sqla(Address) + + with Session(engine) as session: + bavaria = StateProvince(country="de", name="Bavaria") + munich = City(zip_code='80333', name="Munich") + bavaria.cities.append(munich) + maximilian_street = Street(name="Maximilianstrasse") + munich.streets.append(maximilian_street) + some_address = Address() + maximilian_street.addresses.append(some_address) + session.add_all([bavaria, munich, maximilian_street, some_address]) + session.commit() + + address = _PydanticAddress.from_orm(some_address) + + assert address.house_number == '1' + assert address.street.city.state_province.name == "Bavaria" + + +if __name__ == '__main__': + main_test()