generated from daniil-berg/boilerplate-py
Refactor and simplify
This commit is contained in:
parent
20272c4b4a
commit
19224ea34a
@ -13,6 +13,7 @@ from .utils import resolve_dotted_path
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'field_from_column',
|
'field_from_column',
|
||||||
|
'field_from_relationship',
|
||||||
'from_sqla'
|
'from_sqla'
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -27,7 +28,9 @@ class OrmConfig(BaseConfig):
|
|||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
def field_from_column(column: Column) -> FieldDef:
|
def field_from_column(col_prop: ColumnProperty) -> FieldDef:
|
||||||
|
assert len(col_prop.columns) == 1
|
||||||
|
column: Column = col_prop.columns[0]
|
||||||
try:
|
try:
|
||||||
field_type = column.type.impl.python_type
|
field_type = column.type.impl.python_type
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -50,26 +53,26 @@ def field_from_column(column: Column) -> FieldDef:
|
|||||||
return field_type, field_info
|
return field_type, field_info
|
||||||
|
|
||||||
|
|
||||||
def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False,
|
def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef:
|
||||||
config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
|
assert isinstance(rel_prop.mapper, Mapper)
|
||||||
add_fields: dict[str, FieldDef] = None):
|
if rel_prop.direction.name == 'MANYTOONE':
|
||||||
|
return rel_prop.mapper.class_.__name__, Field(default=None)
|
||||||
|
if rel_prop.direction.name == 'ONETOMANY':
|
||||||
|
return list[rel_prop.mapper.class_.__name__], Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
|
||||||
|
incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None):
|
||||||
assert isinstance(db_model, DeclarativeMeta)
|
assert isinstance(db_model, DeclarativeMeta)
|
||||||
assert not (incl_one_to_many and incl_many_to_one), "Pydantic is unable to handle the circular relationship"
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for attr in inspect(db_model).attrs:
|
for attr in inspect(db_model).attrs:
|
||||||
if attr.key in exclude:
|
if attr.key in exclude:
|
||||||
continue
|
continue
|
||||||
if isinstance(attr, ColumnProperty):
|
if isinstance(attr, ColumnProperty):
|
||||||
assert len(attr.columns) == 1
|
fields[attr.key] = field_from_column(attr)
|
||||||
column = attr.columns[0]
|
|
||||||
fields[attr.key] = field_from_column(column)
|
|
||||||
elif isinstance(attr, RelationshipProperty):
|
elif isinstance(attr, RelationshipProperty):
|
||||||
related = attr.mapper
|
if incl_relationships:
|
||||||
assert isinstance(related, Mapper)
|
fields[attr.key] = field_from_relationship(attr)
|
||||||
if incl_many_to_one and attr.direction.name == 'MANYTOONE':
|
|
||||||
fields[attr.key] = (related.class_.__name__, Field(default=None))
|
|
||||||
if incl_one_to_many and attr.direction.name == 'ONETOMANY':
|
|
||||||
fields[attr.key] = (list[related.class_.__name__], Field(default=None))
|
|
||||||
else:
|
else:
|
||||||
raise AssertionError("Unknown attr type", attr)
|
raise AssertionError("Unknown attr type", attr)
|
||||||
if add_fields is not None:
|
if add_fields is not None:
|
||||||
|
@ -17,8 +17,6 @@ def default_factory() -> str: return '1'
|
|||||||
class AbstractBase(ORMBase):
|
class AbstractBase(ORMBase):
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
|
|
||||||
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
|
|
||||||
|
|
||||||
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
|
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
|
||||||
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
||||||
|
|
||||||
@ -37,7 +35,7 @@ class City(AbstractBase):
|
|||||||
__tablename__ = 'city'
|
__tablename__ = 'city'
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True)
|
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||||
zip_code = Column(String(5), nullable=False, index=True)
|
zip_code = Column(String(5), nullable=False, index=True)
|
||||||
name = Column(Unicode(255), nullable=False, index=True)
|
name = Column(Unicode(255), nullable=False, index=True)
|
||||||
|
|
||||||
@ -48,7 +46,7 @@ class Street(AbstractBase):
|
|||||||
__tablename__ = 'street'
|
__tablename__ = 'street'
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True)
|
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||||
name = Column(Unicode(255), nullable=False, index=True)
|
name = Column(Unicode(255), nullable=False, index=True)
|
||||||
|
|
||||||
addresses = relationship('Address', backref='street', lazy='selectin')
|
addresses = relationship('Address', backref='street', lazy='selectin')
|
||||||
@ -58,7 +56,7 @@ class Address(AbstractBase):
|
|||||||
__tablename__ = 'address'
|
__tablename__ = 'address'
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True)
|
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||||
house_number = Column(String(8), nullable=False, default=default_factory)
|
house_number = Column(String(8), nullable=False, default=default_factory)
|
||||||
supplement = Column(String(255))
|
supplement = Column(String(255))
|
||||||
|
|
||||||
@ -67,9 +65,9 @@ def main_test() -> None:
|
|||||||
engine = create_engine("sqlite://")
|
engine = create_engine("sqlite://")
|
||||||
AbstractBase.metadata.create_all(engine)
|
AbstractBase.metadata.create_all(engine)
|
||||||
|
|
||||||
_PydanticStateProvince = from_sqla(StateProvince)
|
_PydanticStateProvince = from_sqla(StateProvince, exclude=['cities'])
|
||||||
_PydanticCity = from_sqla(City)
|
_PydanticCity = from_sqla(City, exclude=['streets'])
|
||||||
_PydanticStreet = from_sqla(Street)
|
_PydanticStreet = from_sqla(Street, exclude=['addresses'])
|
||||||
_PydanticAddress = from_sqla(Address)
|
_PydanticAddress = from_sqla(Address)
|
||||||
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
|
Loading…
Reference in New Issue
Block a user