Refactor and simplify

This commit is contained in:
Daniil Fajnberg 2022-08-06 10:15:00 +02:00
parent 20272c4b4a
commit 19224ea34a
2 changed files with 23 additions and 22 deletions

View File

@ -13,6 +13,7 @@ from .utils import resolve_dotted_path
__all__ = [ __all__ = [
'field_from_column', 'field_from_column',
'field_from_relationship',
'from_sqla' 'from_sqla'
] ]
@ -27,7 +28,9 @@ class OrmConfig(BaseConfig):
orm_mode = True orm_mode = True
def field_from_column(column: Column) -> FieldDef: def field_from_column(col_prop: ColumnProperty) -> FieldDef:
assert len(col_prop.columns) == 1
column: Column = col_prop.columns[0]
try: try:
field_type = column.type.impl.python_type field_type = column.type.impl.python_type
except AttributeError: except AttributeError:
@ -50,26 +53,26 @@ def field_from_column(column: Column) -> FieldDef:
return field_type, field_info return field_type, field_info
def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False, def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef:
config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), assert isinstance(rel_prop.mapper, Mapper)
add_fields: dict[str, FieldDef] = None): if rel_prop.direction.name == 'MANYTOONE':
return rel_prop.mapper.class_.__name__, Field(default=None)
if rel_prop.direction.name == 'ONETOMANY':
return list[rel_prop.mapper.class_.__name__], Field(default=None)
def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None):
assert isinstance(db_model, DeclarativeMeta) assert isinstance(db_model, DeclarativeMeta)
assert not (incl_one_to_many and incl_many_to_one), "Pydantic is unable to handle the circular relationship"
fields = {} fields = {}
for attr in inspect(db_model).attrs: for attr in inspect(db_model).attrs:
if attr.key in exclude: if attr.key in exclude:
continue continue
if isinstance(attr, ColumnProperty): if isinstance(attr, ColumnProperty):
assert len(attr.columns) == 1 fields[attr.key] = field_from_column(attr)
column = attr.columns[0]
fields[attr.key] = field_from_column(column)
elif isinstance(attr, RelationshipProperty): elif isinstance(attr, RelationshipProperty):
related = attr.mapper if incl_relationships:
assert isinstance(related, Mapper) fields[attr.key] = field_from_relationship(attr)
if incl_many_to_one and attr.direction.name == 'MANYTOONE':
fields[attr.key] = (related.class_.__name__, Field(default=None))
if incl_one_to_many and attr.direction.name == 'ONETOMANY':
fields[attr.key] = (list[related.class_.__name__], Field(default=None))
else: else:
raise AssertionError("Unknown attr type", attr) raise AssertionError("Unknown attr type", attr)
if add_fields is not None: if add_fields is not None:

View File

@ -17,8 +17,6 @@ def default_factory() -> str: return '1'
class AbstractBase(ORMBase): class AbstractBase(ORMBase):
__abstract__ = True __abstract__ = True
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now()) date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now()) date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
@ -37,7 +35,7 @@ class City(AbstractBase):
__tablename__ = 'city' __tablename__ = 'city'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True) state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
zip_code = Column(String(5), nullable=False, index=True) zip_code = Column(String(5), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True) name = Column(Unicode(255), nullable=False, index=True)
@ -48,7 +46,7 @@ class Street(AbstractBase):
__tablename__ = 'street' __tablename__ = 'street'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True) city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True) name = Column(Unicode(255), nullable=False, index=True)
addresses = relationship('Address', backref='street', lazy='selectin') addresses = relationship('Address', backref='street', lazy='selectin')
@ -58,7 +56,7 @@ class Address(AbstractBase):
__tablename__ = 'address' __tablename__ = 'address'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True) street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
house_number = Column(String(8), nullable=False, default=default_factory) house_number = Column(String(8), nullable=False, default=default_factory)
supplement = Column(String(255)) supplement = Column(String(255))
@ -67,9 +65,9 @@ def main_test() -> None:
engine = create_engine("sqlite://") engine = create_engine("sqlite://")
AbstractBase.metadata.create_all(engine) AbstractBase.metadata.create_all(engine)
_PydanticStateProvince = from_sqla(StateProvince) _PydanticStateProvince = from_sqla(StateProvince, exclude=['cities'])
_PydanticCity = from_sqla(City) _PydanticCity = from_sqla(City, exclude=['streets'])
_PydanticStreet = from_sqla(Street) _PydanticStreet = from_sqla(Street, exclude=['addresses'])
_PydanticAddress = from_sqla(Address) _PydanticAddress = from_sqla(Address)
with Session(engine) as session: with Session(engine) as session: