Implement first simple working function and tests

This commit is contained in:
Daniil Fajnberg 2022-08-03 11:17:05 +02:00
parent 5f823cb552
commit a64708dedd
3 changed files with 195 additions and 0 deletions

82
src/orm2pydantic/sqla.py Normal file
View File

@ -0,0 +1,82 @@
from typing import Container, Type
from pydantic import create_model, BaseConfig, Field
from pydantic.fields import FieldInfo
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import ColumnProperty, RelationshipProperty, Mapper
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.sql.schema import Column, ColumnDefault
from .utils import resolve_dotted_path
__all__ = [
'field_from_column',
'from_sqla'
]
FieldDef = tuple[type, FieldInfo]
_local_namespace = {}
class OrmConfig(BaseConfig):
orm_mode = True
def field_from_column(column: Column) -> FieldDef:
try:
field_type = column.type.impl.python_type
except AttributeError:
try:
field_type = column.type.python_type
except AttributeError:
raise AssertionError(f"Could not infer Python type for {column.key}")
default = ... if column.default is None and not column.nullable else column.default
if isinstance(default, ColumnDefault):
if default.is_scalar:
field_info = Field(default=default.arg)
else:
assert callable(default.arg)
dotted_path = default.arg.__module__ + '.' + default.arg.__name__
factory = resolve_dotted_path(dotted_path)
assert callable(factory)
field_info = Field(default_factory=factory)
else:
field_info = Field(default=default)
return field_type, field_info
def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False,
config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
add_fields: dict[str, FieldDef] = None):
assert isinstance(db_model, DeclarativeMeta)
assert not (incl_one_to_many and incl_many_to_one)
fields = {}
for attr in inspect(db_model).attrs:
if attr.key in exclude:
continue
if isinstance(attr, ColumnProperty):
assert len(attr.columns) == 1
column = attr.columns[0]
fields[attr.key] = field_from_column(column)
elif isinstance(attr, RelationshipProperty):
related = attr.mapper
assert isinstance(related, Mapper)
if incl_many_to_one and attr.direction.name == 'MANYTOONE':
fields[attr.key] = (related.class_.__name__, Field(default=None))
if incl_one_to_many and attr.direction.name == 'ONETOMANY':
fields[attr.key] = (list[related.class_.__name__], Field(default=None))
else:
raise AssertionError("Unknown attr type", attr)
if add_fields is not None:
fields |= add_fields
name = db_model.__name__
pydantic_model = create_model(name, __config__=config, **fields)
pydantic_model.__name__ = name
pydantic_model.update_forward_refs(**_local_namespace)
_local_namespace[name] = pydantic_model
return pydantic_model

20
src/orm2pydantic/utils.py Normal file
View File

@ -0,0 +1,20 @@
from importlib import import_module
def resolve_dotted_path(dotted_path: str) -> object:
"""
Resolves a dotted path to a global object and returns that object.
Algorithm shamelessly stolen from the `logging.config` module from the standard library.
"""
names = dotted_path.split('.')
module_name = names.pop(0)
found = import_module(module_name)
for name in names:
try:
found = getattr(found, name)
except AttributeError:
module_name += f'.{name}'
import_module(module_name)
found = getattr(found, name)
return found

93
tests/informal.py Normal file
View File

@ -0,0 +1,93 @@
from sqlalchemy.engine.create import create_engine
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column, ForeignKey as FKey
from sqlalchemy.sql.sqltypes import Integer, String, TIMESTAMP, Unicode
from orm2pydantic.sqla import from_sqla
ORMBase = declarative_base()
engine = create_engine("sqlite://")
def default_factory() -> str: return '1'
class AbstractBase(ORMBase):
__abstract__ = True
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
class StateProvince(AbstractBase):
__tablename__ = 'state_province'
id = Column(Integer, primary_key=True)
country = Column(String(2), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
cities = relationship('City', backref='state_province', lazy='selectin')
class City(AbstractBase):
__tablename__ = 'city'
id = Column(Integer, primary_key=True)
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True)
zip_code = Column(String(5), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
streets = relationship('Street', backref='city', lazy='selectin')
class Street(AbstractBase):
__tablename__ = 'street'
id = Column(Integer, primary_key=True)
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True)
name = Column(Unicode(255), nullable=False, index=True)
addresses = relationship('Address', backref='street', lazy='selectin')
class Address(AbstractBase):
__tablename__ = 'address'
id = Column(Integer, primary_key=True)
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True)
house_number = Column(String(8), nullable=False, default=default_factory)
supplement = Column(String(255))
def main_test() -> None:
AbstractBase.metadata.create_all(engine)
from_sqla(StateProvince)
from_sqla(City)
from_sqla(Street)
_PydanticAddress = from_sqla(Address)
with Session(engine) as session:
bavaria = StateProvince(country="de", name="Bavaria")
munich = City(zip_code='80333', name="Munich")
bavaria.cities.append(munich)
maximilian_street = Street(name="Maximilianstrasse")
munich.streets.append(maximilian_street)
some_address = Address()
maximilian_street.addresses.append(some_address)
session.add_all([bavaria, munich, maximilian_street, some_address])
session.commit()
address = _PydanticAddress.from_orm(some_address)
assert address.house_number == '1'
assert address.street.city.state_province.name == "Bavaria"
if __name__ == '__main__':
main_test()