generated from daniil-berg/boilerplate-py
Refactor and simplify
This commit is contained in:
parent
20272c4b4a
commit
19224ea34a
@ -13,6 +13,7 @@ from .utils import resolve_dotted_path
|
||||
|
||||
__all__ = [
|
||||
'field_from_column',
|
||||
'field_from_relationship',
|
||||
'from_sqla'
|
||||
]
|
||||
|
||||
@ -27,7 +28,9 @@ class OrmConfig(BaseConfig):
|
||||
orm_mode = True
|
||||
|
||||
|
||||
def field_from_column(column: Column) -> FieldDef:
|
||||
def field_from_column(col_prop: ColumnProperty) -> FieldDef:
|
||||
assert len(col_prop.columns) == 1
|
||||
column: Column = col_prop.columns[0]
|
||||
try:
|
||||
field_type = column.type.impl.python_type
|
||||
except AttributeError:
|
||||
@ -50,26 +53,26 @@ def field_from_column(column: Column) -> FieldDef:
|
||||
return field_type, field_info
|
||||
|
||||
|
||||
def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False,
|
||||
config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
|
||||
add_fields: dict[str, FieldDef] = None):
|
||||
def field_from_relationship(rel_prop: RelationshipProperty) -> FieldDef:
|
||||
assert isinstance(rel_prop.mapper, Mapper)
|
||||
if rel_prop.direction.name == 'MANYTOONE':
|
||||
return rel_prop.mapper.class_.__name__, Field(default=None)
|
||||
if rel_prop.direction.name == 'ONETOMANY':
|
||||
return list[rel_prop.mapper.class_.__name__], Field(default=None)
|
||||
|
||||
|
||||
def from_sqla(db_model: Type[DeclarativeMeta], config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (),
|
||||
incl_relationships: bool = True, add_fields: dict[str, FieldDef] = None):
|
||||
assert isinstance(db_model, DeclarativeMeta)
|
||||
assert not (incl_one_to_many and incl_many_to_one), "Pydantic is unable to handle the circular relationship"
|
||||
fields = {}
|
||||
for attr in inspect(db_model).attrs:
|
||||
if attr.key in exclude:
|
||||
continue
|
||||
if isinstance(attr, ColumnProperty):
|
||||
assert len(attr.columns) == 1
|
||||
column = attr.columns[0]
|
||||
fields[attr.key] = field_from_column(column)
|
||||
fields[attr.key] = field_from_column(attr)
|
||||
elif isinstance(attr, RelationshipProperty):
|
||||
related = attr.mapper
|
||||
assert isinstance(related, Mapper)
|
||||
if incl_many_to_one and attr.direction.name == 'MANYTOONE':
|
||||
fields[attr.key] = (related.class_.__name__, Field(default=None))
|
||||
if incl_one_to_many and attr.direction.name == 'ONETOMANY':
|
||||
fields[attr.key] = (list[related.class_.__name__], Field(default=None))
|
||||
if incl_relationships:
|
||||
fields[attr.key] = field_from_relationship(attr)
|
||||
else:
|
||||
raise AssertionError("Unknown attr type", attr)
|
||||
if add_fields is not None:
|
||||
|
@ -17,8 +17,6 @@ def default_factory() -> str: return '1'
|
||||
class AbstractBase(ORMBase):
|
||||
__abstract__ = True
|
||||
|
||||
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
|
||||
|
||||
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
|
||||
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
||||
|
||||
@ -37,7 +35,7 @@ class City(AbstractBase):
|
||||
__tablename__ = 'city'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True)
|
||||
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
zip_code = Column(String(5), nullable=False, index=True)
|
||||
name = Column(Unicode(255), nullable=False, index=True)
|
||||
|
||||
@ -48,7 +46,7 @@ class Street(AbstractBase):
|
||||
__tablename__ = 'street'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True)
|
||||
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
name = Column(Unicode(255), nullable=False, index=True)
|
||||
|
||||
addresses = relationship('Address', backref='street', lazy='selectin')
|
||||
@ -58,7 +56,7 @@ class Address(AbstractBase):
|
||||
__tablename__ = 'address'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True)
|
||||
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
house_number = Column(String(8), nullable=False, default=default_factory)
|
||||
supplement = Column(String(255))
|
||||
|
||||
@ -67,9 +65,9 @@ def main_test() -> None:
|
||||
engine = create_engine("sqlite://")
|
||||
AbstractBase.metadata.create_all(engine)
|
||||
|
||||
_PydanticStateProvince = from_sqla(StateProvince)
|
||||
_PydanticCity = from_sqla(City)
|
||||
_PydanticStreet = from_sqla(Street)
|
||||
_PydanticStateProvince = from_sqla(StateProvince, exclude=['cities'])
|
||||
_PydanticCity = from_sqla(City, exclude=['streets'])
|
||||
_PydanticStreet = from_sqla(Street, exclude=['addresses'])
|
||||
_PydanticAddress = from_sqla(Address)
|
||||
|
||||
with Session(engine) as session:
|
||||
|
Loading…
Reference in New Issue
Block a user