Write a few SQLModels;

set up two basic CRUD-like functions;
create two test routes;
make `db_uri` setting more flexible
This commit is contained in:
Daniil Fajnberg 2022-08-10 21:50:00 +02:00
parent 60c6b1c4bb
commit 9f487d515d
13 changed files with 435 additions and 236 deletions

View File

@ -1,7 +1,8 @@
Pydantic
FastAPI
SQLAlchemy[asyncio]
SQLAlchemy[asyncio]==1.4.35
Alembic
SQLAlchemy-Utils
SQLModel
Babel
python-slugify
Python-Slugify

View File

@ -21,15 +21,16 @@ classifiers =
package_dir =
= src
packages = find:
python_requires = >=3.10
python_requires = >=3.10, <4
install_requires =
Pydantic
FastAPI
SQLAlchemy[asyncio]
SQLAlchemy[asyncio]==1.4.35
Alembic
SQLAlchemy-Utils
SQLModel
Babel
python-slugify
Python-Slugify
[options.extras_require]
srv =

View File

@ -0,0 +1 @@
from .geography import *

View File

@ -0,0 +1,33 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.expression import select
from compub.models.geography import *
__all__ = [
# 'get_state_by_name_and_country',
'get_states',
'create_state'
]
# async def get_state_by_name_and_country(session: AsyncSession, name: str, country: str) -> db.StateProvince:
# statement = select(db.StateProvince).filter(
# db.StateProvince.name == name,
# db.StateProvince.country == country
# ).limit(1)
# result = await session.execute(statement)
# return result.scalars().first()
async def get_states(session: AsyncSession, skip: int = 0, limit: int = 100) -> list[StateProvince]:
statement = select(StateProvince).offset(skip).limit(limit)
result = await session.execute(statement)
return result.scalars().all()
async def create_state(session: AsyncSession, state: StateProvince) -> StateProvince:
session.add(state)
await session.commit()
await session.refresh(state)
return state

View File

@ -1,35 +0,0 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import TIMESTAMP
from sqlalchemy_utils.functions.orm import get_columns
from sqlalchemy_utils.listeners import force_auto_coercion
from compub.exceptions import NoDatabaseConfigured
from compub.settings import settings
if settings.db_uri is None:
raise NoDatabaseConfigured
engine = create_async_engine(settings.db_uri, future=True)
LocalAsyncSession = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
ORMBase = declarative_base()
force_auto_coercion()
class AbstractBase(ORMBase):
__abstract__ = True
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
def __repr__(self) -> str:
# Exclude non-representative fields:
fields = (name for name in get_columns(self).keys() if name not in self.NON_REPR_FIELDS)
# Exclude NULL value fields:
attrs = ', '.join(f"{name}={repr(getattr(self, name))}" for name in fields if getattr(self, name) is not None)
return f"{self.__class__.__name__}({attrs})"

View File

@ -1,138 +0,0 @@
from datetime import date
from typing import Optional
from uuid import uuid4
from slugify import slugify
from sqlalchemy.engine import Connection
from sqlalchemy.event.api import listens_for
from sqlalchemy.orm import relationship
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table
from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode
from sqlalchemy_utils.types import CountryType, UUIDType
from compub.utils import multi_max
from .base import AbstractBase, ORMBase
__all__ = [
'LegalForm',
'LegalFormSubcategory',
'Integer',
'Company',
'CompanyName',
]
class LegalForm(AbstractBase):
__tablename__ = 'legal_form'
id = Column(Integer, primary_key=True)
short = Column(String(32), nullable=False, index=True)
name = Column(Unicode(255))
country = Column(CountryType)
subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin')
def __str__(self) -> str:
return str(self.short)
class LegalFormSubcategory(AbstractBase):
__tablename__ = 'legal_form_subcategory'
id = Column(Integer, primary_key=True)
legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True)
short = Column(String(32), nullable=False, index=True)
name = Column(Unicode(255))
companies = relationship('Company', backref='legal_form', lazy='selectin')
def __str__(self) -> str:
return str(self.short)
class Industry(AbstractBase):
__tablename__ = 'industry'
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False, index=True)
companies = relationship('Company', secondary='company_industries', back_populates='industries')
def __str__(self) -> str:
return str(self.name)
class Company(AbstractBase):
__tablename__ = 'company'
NAME_SORTING_PARAMS = ( # passed to `multi_max`
lambda obj: date(1, 1, 1) if obj.date_registered is None else obj.date_registered,
'date_updated',
)
# If date_registered is None, the name is considered to be older than one with a date_registered
id = Column(UUIDType, primary_key=True, default=uuid4)
visible = Column(Boolean, default=True, nullable=False, index=True)
legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True)
insolvent = Column(Boolean, default=False, nullable=False, index=True)
founding_date = Column(Date)
liquidation_date = Column(Date)
# TODO: Get rid of city; implement address properly
city = Column(String(255), index=True)
address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True)
industries = relationship('Industry', secondary='company_industries', back_populates='companies')
names = relationship('CompanyName', backref='company', lazy='selectin')
def __str__(self) -> str:
return str(self.current_name or f"<Company {self.id}>")
@property
def current_name(self) -> Optional['CompanyName']:
return multi_max(list(self.names), *self.NAME_SORTING_PARAMS, default=None)
company_industries = Table(
'company_industries',
ORMBase.metadata,
Column('company_id', FKey('company.id'), primary_key=True),
Column('industry_id', FKey('industry.id'), primary_key=True),
)
class CompanyName(AbstractBase):
__tablename__ = 'company_name'
MAX_SLUG_LENGTH = 255
id = Column(UUIDType, primary_key=True, default=uuid4)
name = Column(Unicode(768), nullable=False, index=True)
company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True)
date_registered = Column(Date)
slug = Column(String(MAX_SLUG_LENGTH), index=True)
def __str__(self) -> str:
return str(self.name)
@listens_for(CompanyName, 'before_insert')
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
if target.slug:
return
slug = slugify(target.name)[:(target.MAX_SLUG_LENGTH - 2)]
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
num = connection.execute(statement).scalar()
if num == 0:
target.slug = slug
else:
target.slug = f'{slug}-{str(num + 1)}'
def get_reg_date(obj: CompanyName) -> date:
if obj.date_registered is None:
return date(1, 1, 1)
return obj.date_registered

View File

@ -1,56 +0,0 @@
from uuid import uuid4
from sqlalchemy.orm import relationship
from sqlalchemy.sql.schema import Column, ForeignKey as FKey
from sqlalchemy.sql.sqltypes import Integer, String, Unicode
from sqlalchemy_utils.types import CountryType, UUIDType
from .base import AbstractBase
__all__ = [
'StateProvince',
'City',
'Street',
'Address',
]
class StateProvince(AbstractBase):
__tablename__ = 'state_province'
id = Column(Integer, primary_key=True)
country = Column(CountryType, nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
cities = relationship('City', backref='state_province', lazy='selectin')
class City(AbstractBase):
__tablename__ = 'city'
id = Column(Integer, primary_key=True)
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
zip_code = Column(String(5), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
streets = relationship('Street', backref='city', lazy='selectin')
class Street(AbstractBase):
__tablename__ = 'street'
id = Column(Integer, primary_key=True)
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
name = Column(Unicode(255), nullable=False, index=True)
addresses = relationship('Address', backref='street', lazy='selectin')
class Address(AbstractBase):
__tablename__ = 'address'
id = Column(UUIDType, primary_key=True, default=uuid4)
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
house_number = Column(String(8), nullable=False)
supplement = Column(String(255))

View File

@ -1,2 +1,3 @@
from .base import DB
from .geography import *
from .companies import *

80
src/compub/models/base.py Normal file
View File

@ -0,0 +1,80 @@
from datetime import datetime
from typing import Any, Iterator, Optional
from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.future.engine import Engine
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import TIMESTAMP
from sqlalchemy_utils.functions.orm import get_columns
from sqlalchemy_utils.listeners import force_auto_coercion
from sqlmodel.main import Field, SQLModel
from compub.exceptions import NoDatabaseConfigured
from compub.settings import settings
from sqlmodel import create_engine, Session
class DB:
def __init__(self):
self._engine: AsyncEngine | Engine | None = None
self._session_maker: sessionmaker | None = None
def start_engine(self) -> None:
if settings.db_uri is None:
raise NoDatabaseConfigured
if settings.db_uri.scheme == 'sqlite':
self._engine = create_engine(settings.db_uri)
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
else:
self._engine = create_async_engine(settings.db_uri, future=True)
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
force_auto_coercion()
@property
def engine(self):
if self._engine is None:
self.start_engine()
assert isinstance(self._engine, (AsyncEngine, Engine))
return self._engine
async def get_session(self) -> Session | AsyncSession:
if self._session_maker is None:
self.start_engine()
assert isinstance(self._session_maker, sessionmaker)
session = self._session_maker()
try:
yield session
finally:
await session.close()
class AbstractBase(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
date_created: Optional[datetime] = Field(
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
)
date_updated: Optional[datetime] = Field(
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
)
def __repr__(self) -> str:
fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
attrs = ', '.join(f"{name}={repr(value)}" for name, value in fields)
return f"{self.__class__.__name__}({attrs})"
def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
for name in get_columns(self).keys():
if excl_non_repr and name in self.get_non_repr_fields():
continue
value = getattr(self, name)
if excl_null_value and value is None:
continue
yield name, value
@staticmethod
def get_non_repr_fields() -> list[str]:
return ['id', 'date_created', 'date_updated']

View File

@ -0,0 +1,170 @@
from datetime import date
from typing import Optional
from slugify import slugify
from sqlalchemy.engine.base import Connection
from sqlalchemy.event.api import listens_for
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import Unicode
from sqlalchemy_utils.types import CountryType
from sqlmodel.main import Field, Relationship
from compub.utils import multi_max
from .base import AbstractBase
__all__ = [
'LegalForm',
'LegalFormSubcategory',
'Industry',
'Company',
'CompanyName',
]
class LegalForm(AbstractBase, table=True):
__tablename__ = 'legal_form'
# Fields
short: str = Field(max_length=32, nullable=False, index=True)
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
country: str = Field(sa_column=Column(CountryType))
# Relationships
subcategories: list['LegalFormSubcategory'] = Relationship(
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.short)
class LegalFormSubcategory(AbstractBase, table=True):
__tablename__ = 'legal_form_subcategory'
# Fields
short: str = Field(max_length=32, nullable=False, index=True)
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
# Relationships
legal_form_id: Optional[int] = Field(
foreign_key='legal_form.id', default=None, nullable=False, index=True
)
legal_form: Optional[LegalForm] = Relationship(
back_populates='subcategories', sa_relationship_kwargs={'lazy': 'selectin'}
)
companies: list['Company'] = Relationship(
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.short)
class CompanyIndustryLink(AbstractBase, table=True):
__tablename__ = 'company_industries'
# Relationships
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
industry_id: Optional[int] = Field(foreign_key='industry.id', default=None, nullable=False, primary_key=True)
class Industry(AbstractBase, table=True):
__tablename__ = 'industry'
# Fields
name: str = Field(max_length=255, nullable=False, index=True)
# Relationships
companies: list['Company'] = Relationship(
back_populates='industries', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.name)
class Company(AbstractBase, table=True):
__tablename__ = 'company'
# Fields
visible: bool = Field(default=True, nullable=False, index=True)
insolvent: bool = Field(default=False, nullable=False, index=True)
founding_data: Optional[date]
liquidation_date: Optional[date]
city: str = Field(max_length=255, index=True) # TODO: Get rid of city; implement address properly
# Relationships
legal_form_id: Optional[int] = Field(
foreign_key='legal_form_subcategory.id', default=None, index=True
)
legal_form: Optional[LegalFormSubcategory] = Relationship(
back_populates='companies', sa_relationship_kwargs={'lazy': 'selectin'}
)
address_id: Optional[int] = Field(
foreign_key='address.id', default=None, index=True
)
industries: list[Industry] = Relationship(
back_populates='companies', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
)
names: list['CompanyName'] = Relationship(
back_populates='company', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.current_name or f"<Company {self.id}>")
@property
def current_name(self) -> Optional['CompanyName']:
return multi_max(list(self.names), CompanyName.get_reg_date, 'date_updated', default=None)
class CompanyName(AbstractBase, table=True):
__tablename__ = 'company_name'
__MAX_NAME_LENGTH__: int = 768
__MAX_SLUG_LENGTH__: int = 255
# Fields
name: str = Field(
max_length=__MAX_NAME_LENGTH__, sa_column=Column(Unicode(__MAX_NAME_LENGTH__), nullable=False, index=True)
)
date_registered: Optional[date]
slug: Optional[str] = Field(default=None, max_length=__MAX_SLUG_LENGTH__, index=True)
# Relationships
company_id: Optional[int] = Field(
foreign_key='company.id', default=None, nullable=False, index=True
)
company: Optional[Company] = Relationship(
back_populates='names', sa_relationship_kwargs={'lazy': 'selectin'}
)
def __str__(self) -> str:
return str(self.name)
def get_reg_date(self) -> date:
return date(1, 1, 1) if self.date_registered is None else self.date_registered
@property
def max_slug_length(self) -> int:
return self.__MAX_SLUG_LENGTH__
@listens_for(CompanyName, 'before_insert')
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
if target.slug:
return
slug = slugify(target.name)[:(target.max_slug_length - 2)]
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
num = connection.execute(statement).scalar()
if num == 0:
target.slug = slug
else:
target.slug = f'{slug}-{str(num + 1)}'

View File

@ -0,0 +1,94 @@
from typing import Any, Optional
from pydantic import validator
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import Unicode
from sqlalchemy_utils.primitives.country import Country
from sqlalchemy_utils.types import CountryType
from sqlmodel.main import Field, Relationship
from .base import AbstractBase
__all__ = [
'StateProvince',
'City',
'Street',
'Address',
]
class StateProvince(AbstractBase, table=True):
__tablename__ = 'state_province'
# Fields
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
# Relationships
cities: list['City'] = Relationship(
back_populates='state_province', sa_relationship_kwargs={'lazy': 'selectin'}
)
@validator('country', pre=True)
def country_as_uppercase_string(cls, v: Any) -> str:
if isinstance(v, Country):
return v.code
if isinstance(v, str):
return v.upper()
raise TypeError
class City(AbstractBase, table=True):
__tablename__ = 'city'
# Fields
zip_code: str = Field(max_length=5, nullable=False, index=True)
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
# Relationships
state_province_id: Optional[int] = Field(
foreign_key='state_province.id', default=None, nullable=False, index=True
)
state_province: Optional[StateProvince] = Relationship(
back_populates='cities', sa_relationship_kwargs={'lazy': 'selectin'}
)
streets: list['Street'] = Relationship(
back_populates='city', sa_relationship_kwargs={'lazy': 'selectin'}
)
class Street(AbstractBase, table=True):
__tablename__ = 'street'
# Fields
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
# Relationships
city_id: Optional[int] = Field(
foreign_key='city.id', default=None, nullable=False, index=True
)
city: Optional[City] = Relationship(
back_populates='streets', sa_relationship_kwargs={'lazy': 'selectin'}
)
addresses: list['Address'] = Relationship(
back_populates='street', sa_relationship_kwargs={'lazy': 'selectin'}
)
class Address(AbstractBase, table=True):
__tablename__ = 'address'
# Fields
house_number: str = Field(max_length=8, nullable=False)
supplement: str = Field(max_length=255)
# Relationships
street_id: Optional[int] = Field(
foreign_key='street.id', default=None, nullable=False, index=True
)
street: Optional[Street] = Relationship(
back_populates='addresses', sa_relationship_kwargs={'lazy': 'selectin'}
)

43
src/compub/routes.py Normal file
View File

@ -0,0 +1,43 @@
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from .models import *
from .models.base import DB
from . import crud
api, db = FastAPI(), DB()
@api.on_event('startup')
async def initialize_db():
async with db.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
await conn.run_sync(SQLModel.metadata.create_all)
@api.on_event("shutdown")
async def close_connection_pool():
await db.engine.dispose()
@api.post("/states/", response_model=StateProvince)
async def create_state(state: StateProvince, session: AsyncSession = Depends(db.get_session)):
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
# if db_obj:
# raise HTTPException(status_code=400, detail="Name already registered")
return await crud.create_state(session, state=state)
@api.get("/states/", response_model=list[StateProvince])
async def list_states(skip: int = 0, limit: int = 100, session: AsyncSession = Depends(db.get_session)):
return await crud.get_states(session, skip=skip, limit=limit)
# @app.post("/cities/", response_model=City)
# async def create_city(city: City, session: AsyncSession = Depends(get_session)):
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
# if db_obj:
# raise HTTPException(status_code=400, detail="Name already registered")
# return await crud.create_state(session, state=state)

View File

@ -3,7 +3,7 @@ import logging.config
from pathlib import Path
from typing import Any, Callable, ClassVar
from pydantic import BaseModel, BaseSettings, PostgresDsn, validator
from pydantic import BaseModel, BaseSettings, AnyUrl, validator
from pydantic.env_settings import SettingsSourceCallable
from yaml import safe_load
@ -76,8 +76,12 @@ class ServerSettings(BaseModel):
uds: str | None = None
class DBURI(AnyUrl):
host_required = False
class Settings(AbstractBaseSettings):
db_uri: PostgresDsn | None = None
db_uri: DBURI | None = None
server: ServerSettings = ServerSettings()
log_config: dict | Path | None = None