diff --git a/src/orm2pydantic/sqla.py b/src/orm2pydantic/sqla.py index 06d96d8..12e7feb 100644 --- a/src/orm2pydantic/sqla.py +++ b/src/orm2pydantic/sqla.py @@ -13,6 +13,7 @@ from .utils import resolve_dotted_path __all__ = [ 'field_from_column', + 'field_from_relationship', 'from_sqla' ] @@ -27,7 +28,9 @@ class OrmConfig(BaseConfig): orm_mode = True -def field_from_column(column: Column) -> FieldDef: +def field_from_column(col_prop: ColumnProperty) -> FieldDef: + assert len(col_prop.columns) == 1 + column: Column = col_prop.columns[0] try: field_type = column.type.impl.python_type except AttributeError: @@ -50,26 +53,26 @@ def field_from_column(column: Column) -> FieldDef: return field_type, field_info -def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False, - config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), - add_fields: dict[str, FieldDef] = None): +def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef: + assert isinstance(rel_prop.mapper, Mapper) + if rel_prop.direction.name == 'MANYTOONE': + return rel_prop.mapper.class_.__name__, Field(default=None) + if rel_prop.direction.name == 'ONETOMANY': + return list[rel_prop.mapper.class_.__name__], Field(default=None) + + +def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), + incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None): assert isinstance(db_model, DeclarativeMeta) - assert not (incl_one_to_many and incl_many_to_one), "Pydantic is unable to handle the circular relationship" fields = {} for attr in inspect(db_model).attrs: if attr.key in exclude: continue if isinstance(attr, ColumnProperty): - assert len(attr.columns) == 1 - column = attr.columns[0] - fields[attr.key] = field_from_column(column) + fields[attr.key] = field_from_column(attr) elif isinstance(attr, RelationshipProperty): - related = attr.mapper - assert isinstance(related, Mapper) - if incl_many_to_one and attr.direction.name == 'MANYTOONE': - fields[attr.key] = (related.class_.__name__, Field(default=None)) - if incl_one_to_many and attr.direction.name == 'ONETOMANY': - fields[attr.key] = (list[related.class_.__name__], Field(default=None)) + if incl_relationships: + fields[attr.key] = field_from_relationship(attr) else: raise AssertionError("Unknown attr type", attr) if add_fields is not None: diff --git a/tests/informal.py b/tests/informal.py index 641762d..dabdd14 100644 --- a/tests/informal.py +++ b/tests/informal.py @@ -17,8 +17,6 @@ def default_factory() -> str: return '1' class AbstractBase(ORMBase): __abstract__ = True - NON_REPR_FIELDS = ['id', 'date_created', 'date_updated'] - date_created = Column(TIMESTAMP(timezone=False), server_default=db_now()) date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now()) @@ -37,7 +35,7 @@ class City(AbstractBase): __tablename__ = 'city' id = Column(Integer, primary_key=True) - state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True) + state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True) zip_code = Column(String(5), nullable=False, index=True) name = Column(Unicode(255), nullable=False, index=True) @@ -48,7 +46,7 @@ class Street(AbstractBase): __tablename__ = 'street' id = Column(Integer, primary_key=True) - city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True) + city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True) name = Column(Unicode(255), nullable=False, index=True) addresses = relationship('Address', backref='street', lazy='selectin') @@ -58,7 +56,7 @@ class Address(AbstractBase): __tablename__ = 'address' id = Column(Integer, primary_key=True) - street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True) + street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True) house_number = Column(String(8), nullable=False, default=default_factory) supplement = Column(String(255)) @@ -67,9 +65,9 @@ def main_test() -> None: engine = create_engine("sqlite://") AbstractBase.metadata.create_all(engine) - _PydanticStateProvince = from_sqla(StateProvince) - _PydanticCity = from_sqla(City) - _PydanticStreet = from_sqla(Street) + _PydanticStateProvince = from_sqla(StateProvince, exclude=['cities']) + _PydanticCity = from_sqla(City, exclude=['streets']) + _PydanticStreet = from_sqla(Street, exclude=['addresses']) _PydanticAddress = from_sqla(Address) with Session(engine) as session: