generated from daniil-berg/boilerplate-py
	Implement first simple working function and tests
This commit is contained in:
		
							
								
								
									
										82
									
								
								src/orm2pydantic/sqla.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								src/orm2pydantic/sqla.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | ||||
| from typing import Container, Type | ||||
|  | ||||
| from pydantic import create_model, BaseConfig, Field | ||||
| from pydantic.fields import FieldInfo | ||||
|  | ||||
| from sqlalchemy.inspection import inspect | ||||
| from sqlalchemy.orm import ColumnProperty, RelationshipProperty, Mapper | ||||
| from sqlalchemy.orm.decl_api import DeclarativeMeta | ||||
| from sqlalchemy.sql.schema import Column, ColumnDefault | ||||
|  | ||||
| from .utils import resolve_dotted_path | ||||
|  | ||||
|  | ||||
| __all__ = [ | ||||
|     'field_from_column', | ||||
|     'from_sqla' | ||||
| ] | ||||
|  | ||||
|  | ||||
| FieldDef = tuple[type, FieldInfo] | ||||
|  | ||||
|  | ||||
| _local_namespace = {} | ||||
|  | ||||
|  | ||||
| class OrmConfig(BaseConfig): | ||||
|     orm_mode = True | ||||
|  | ||||
|  | ||||
| def field_from_column(column: Column) -> FieldDef: | ||||
|     try: | ||||
|         field_type = column.type.impl.python_type | ||||
|     except AttributeError: | ||||
|         try: | ||||
|             field_type = column.type.python_type | ||||
|         except AttributeError: | ||||
|             raise AssertionError(f"Could not infer Python type for {column.key}") | ||||
|     default = ... if column.default is None and not column.nullable else column.default | ||||
|     if isinstance(default, ColumnDefault): | ||||
|         if default.is_scalar: | ||||
|             field_info = Field(default=default.arg) | ||||
|         else: | ||||
|             assert callable(default.arg) | ||||
|             dotted_path = default.arg.__module__ + '.' + default.arg.__name__ | ||||
|             factory = resolve_dotted_path(dotted_path) | ||||
|             assert callable(factory) | ||||
|             field_info = Field(default_factory=factory) | ||||
|     else: | ||||
|         field_info = Field(default=default) | ||||
|     return field_type, field_info | ||||
|  | ||||
|  | ||||
| def from_sqla(db_model: Type[DeclarativeMeta], incl_many_to_one: bool = True, incl_one_to_many: bool = False, | ||||
|               config: Type[BaseConfig] = OrmConfig, exclude: Container[str] = (), | ||||
|               add_fields: dict[str, FieldDef] = None): | ||||
|     assert isinstance(db_model, DeclarativeMeta) | ||||
|     assert not (incl_one_to_many and incl_many_to_one) | ||||
|     fields = {} | ||||
|     for attr in inspect(db_model).attrs: | ||||
|         if attr.key in exclude: | ||||
|             continue | ||||
|         if isinstance(attr, ColumnProperty): | ||||
|             assert len(attr.columns) == 1 | ||||
|             column = attr.columns[0] | ||||
|             fields[attr.key] = field_from_column(column) | ||||
|         elif isinstance(attr, RelationshipProperty): | ||||
|             related = attr.mapper | ||||
|             assert isinstance(related, Mapper) | ||||
|             if incl_many_to_one and attr.direction.name == 'MANYTOONE': | ||||
|                 fields[attr.key] = (related.class_.__name__, Field(default=None)) | ||||
|             if incl_one_to_many and attr.direction.name == 'ONETOMANY': | ||||
|                 fields[attr.key] = (list[related.class_.__name__], Field(default=None)) | ||||
|         else: | ||||
|             raise AssertionError("Unknown attr type", attr) | ||||
|     if add_fields is not None: | ||||
|         fields |= add_fields | ||||
|     name = db_model.__name__ | ||||
|     pydantic_model = create_model(name, __config__=config, **fields) | ||||
|     pydantic_model.__name__ = name | ||||
|     pydantic_model.update_forward_refs(**_local_namespace) | ||||
|     _local_namespace[name] = pydantic_model | ||||
|     return pydantic_model | ||||
							
								
								
									
										20
									
								
								src/orm2pydantic/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								src/orm2pydantic/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| from importlib import import_module | ||||
|  | ||||
|  | ||||
| def resolve_dotted_path(dotted_path: str) -> object: | ||||
|     """ | ||||
|     Resolves a dotted path to a global object and returns that object. | ||||
|  | ||||
|     Algorithm shamelessly stolen from the `logging.config` module from the standard library. | ||||
|     """ | ||||
|     names = dotted_path.split('.') | ||||
|     module_name = names.pop(0) | ||||
|     found = import_module(module_name) | ||||
|     for name in names: | ||||
|         try: | ||||
|             found = getattr(found, name) | ||||
|         except AttributeError: | ||||
|             module_name += f'.{name}' | ||||
|             import_module(module_name) | ||||
|             found = getattr(found, name) | ||||
|     return found | ||||
							
								
								
									
										93
									
								
								tests/informal.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								tests/informal.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | ||||
| from sqlalchemy.engine.create import create_engine | ||||
| from sqlalchemy.orm import declarative_base, relationship | ||||
| from sqlalchemy.orm.session import Session | ||||
| from sqlalchemy.sql.functions import now as db_now | ||||
| from sqlalchemy.sql.schema import Column, ForeignKey as FKey | ||||
| from sqlalchemy.sql.sqltypes import Integer, String, TIMESTAMP, Unicode | ||||
|  | ||||
| from orm2pydantic.sqla import from_sqla | ||||
|  | ||||
|  | ||||
| ORMBase = declarative_base() | ||||
| engine = create_engine("sqlite://") | ||||
|  | ||||
|  | ||||
| def default_factory() -> str: return '1' | ||||
|  | ||||
|  | ||||
| class AbstractBase(ORMBase): | ||||
|     __abstract__ = True | ||||
|  | ||||
|     NON_REPR_FIELDS = ['id', 'date_created', 'date_updated'] | ||||
|  | ||||
|     date_created = Column(TIMESTAMP(timezone=False), server_default=db_now()) | ||||
|     date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now()) | ||||
|  | ||||
|  | ||||
| class StateProvince(AbstractBase): | ||||
|     __tablename__ = 'state_province' | ||||
|  | ||||
|     id = Column(Integer, primary_key=True) | ||||
|     country = Column(String(2), nullable=False, index=True) | ||||
|     name = Column(Unicode(255), nullable=False, index=True) | ||||
|  | ||||
|     cities = relationship('City', backref='state_province', lazy='selectin') | ||||
|  | ||||
|  | ||||
| class City(AbstractBase): | ||||
|     __tablename__ = 'city' | ||||
|  | ||||
|     id = Column(Integer, primary_key=True) | ||||
|     state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), index=True) | ||||
|     zip_code = Column(String(5), nullable=False, index=True) | ||||
|     name = Column(Unicode(255), nullable=False, index=True) | ||||
|  | ||||
|     streets = relationship('Street', backref='city', lazy='selectin') | ||||
|  | ||||
|  | ||||
| class Street(AbstractBase): | ||||
|     __tablename__ = 'street' | ||||
|  | ||||
|     id = Column(Integer, primary_key=True) | ||||
|     city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), index=True) | ||||
|     name = Column(Unicode(255), nullable=False, index=True) | ||||
|  | ||||
|     addresses = relationship('Address', backref='street', lazy='selectin') | ||||
|  | ||||
|  | ||||
| class Address(AbstractBase): | ||||
|     __tablename__ = 'address' | ||||
|  | ||||
|     id = Column(Integer, primary_key=True) | ||||
|     street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), index=True) | ||||
|     house_number = Column(String(8), nullable=False, default=default_factory) | ||||
|     supplement = Column(String(255)) | ||||
|  | ||||
|  | ||||
| def main_test() -> None: | ||||
|     AbstractBase.metadata.create_all(engine) | ||||
|  | ||||
|     from_sqla(StateProvince) | ||||
|     from_sqla(City) | ||||
|     from_sqla(Street) | ||||
|     _PydanticAddress = from_sqla(Address) | ||||
|  | ||||
|     with Session(engine) as session: | ||||
|         bavaria = StateProvince(country="de", name="Bavaria") | ||||
|         munich = City(zip_code='80333', name="Munich") | ||||
|         bavaria.cities.append(munich) | ||||
|         maximilian_street = Street(name="Maximilianstrasse") | ||||
|         munich.streets.append(maximilian_street) | ||||
|         some_address = Address() | ||||
|         maximilian_street.addresses.append(some_address) | ||||
|         session.add_all([bavaria, munich, maximilian_street, some_address]) | ||||
|         session.commit() | ||||
|  | ||||
|         address = _PydanticAddress.from_orm(some_address) | ||||
|  | ||||
|     assert address.house_number == '1' | ||||
|     assert address.street.city.state_province.name == "Bavaria" | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main_test() | ||||
		Reference in New Issue
	
	Block a user