generated from daniil-berg/boilerplate-py
Write a few SQLModels;
set up two basic CRUD-like functions; create two test routes; make `db_uri` setting more flexible
This commit is contained in:
parent
60c6b1c4bb
commit
9f487d515d
@ -1,7 +1,8 @@
|
||||
Pydantic
|
||||
FastAPI
|
||||
SQLAlchemy[asyncio]
|
||||
SQLAlchemy[asyncio]==1.4.35
|
||||
Alembic
|
||||
SQLAlchemy-Utils
|
||||
SQLModel
|
||||
Babel
|
||||
python-slugify
|
||||
Python-Slugify
|
@ -21,15 +21,16 @@ classifiers =
|
||||
package_dir =
|
||||
= src
|
||||
packages = find:
|
||||
python_requires = >=3.10
|
||||
python_requires = >=3.10, <4
|
||||
install_requires =
|
||||
Pydantic
|
||||
FastAPI
|
||||
SQLAlchemy[asyncio]
|
||||
SQLAlchemy[asyncio]==1.4.35
|
||||
Alembic
|
||||
SQLAlchemy-Utils
|
||||
SQLModel
|
||||
Babel
|
||||
python-slugify
|
||||
Python-Slugify
|
||||
|
||||
[options.extras_require]
|
||||
srv =
|
||||
|
1
src/compub/crud/__init__.py
Normal file
1
src/compub/crud/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .geography import *
|
33
src/compub/crud/geography.py
Normal file
33
src/compub/crud/geography.py
Normal file
@ -0,0 +1,33 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql.expression import select
|
||||
|
||||
from compub.models.geography import *
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 'get_state_by_name_and_country',
|
||||
'get_states',
|
||||
'create_state'
|
||||
]
|
||||
|
||||
|
||||
# async def get_state_by_name_and_country(session: AsyncSession, name: str, country: str) -> db.StateProvince:
|
||||
# statement = select(db.StateProvince).filter(
|
||||
# db.StateProvince.name == name,
|
||||
# db.StateProvince.country == country
|
||||
# ).limit(1)
|
||||
# result = await session.execute(statement)
|
||||
# return result.scalars().first()
|
||||
|
||||
|
||||
async def get_states(session: AsyncSession, skip: int = 0, limit: int = 100) -> list[StateProvince]:
|
||||
statement = select(StateProvince).offset(skip).limit(limit)
|
||||
result = await session.execute(statement)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def create_state(session: AsyncSession, state: StateProvince) -> StateProvince:
|
||||
session.add(state)
|
||||
await session.commit()
|
||||
await session.refresh(state)
|
||||
return state
|
@ -1,35 +0,0 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy.sql.functions import now as db_now
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import TIMESTAMP
|
||||
from sqlalchemy_utils.functions.orm import get_columns
|
||||
from sqlalchemy_utils.listeners import force_auto_coercion
|
||||
|
||||
from compub.exceptions import NoDatabaseConfigured
|
||||
from compub.settings import settings
|
||||
|
||||
|
||||
if settings.db_uri is None:
|
||||
raise NoDatabaseConfigured
|
||||
engine = create_async_engine(settings.db_uri, future=True)
|
||||
LocalAsyncSession = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
ORMBase = declarative_base()
|
||||
|
||||
force_auto_coercion()
|
||||
|
||||
|
||||
class AbstractBase(ORMBase):
|
||||
__abstract__ = True
|
||||
|
||||
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
|
||||
|
||||
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
|
||||
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Exclude non-representative fields:
|
||||
fields = (name for name in get_columns(self).keys() if name not in self.NON_REPR_FIELDS)
|
||||
# Exclude NULL value fields:
|
||||
attrs = ', '.join(f"{name}={repr(getattr(self, name))}" for name in fields if getattr(self, name) is not None)
|
||||
return f"{self.__class__.__name__}({attrs})"
|
@ -1,138 +0,0 @@
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from slugify import slugify
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.event.api import listens_for
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.sql.expression import select
|
||||
from sqlalchemy.sql.functions import count
|
||||
from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table
|
||||
from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode
|
||||
from sqlalchemy_utils.types import CountryType, UUIDType
|
||||
|
||||
from compub.utils import multi_max
|
||||
from .base import AbstractBase, ORMBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LegalForm',
|
||||
'LegalFormSubcategory',
|
||||
'Integer',
|
||||
'Company',
|
||||
'CompanyName',
|
||||
]
|
||||
|
||||
|
||||
class LegalForm(AbstractBase):
|
||||
__tablename__ = 'legal_form'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
short = Column(String(32), nullable=False, index=True)
|
||||
name = Column(Unicode(255))
|
||||
country = Column(CountryType)
|
||||
|
||||
subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.short)
|
||||
|
||||
|
||||
class LegalFormSubcategory(AbstractBase):
|
||||
__tablename__ = 'legal_form_subcategory'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
short = Column(String(32), nullable=False, index=True)
|
||||
name = Column(Unicode(255))
|
||||
|
||||
companies = relationship('Company', backref='legal_form', lazy='selectin')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.short)
|
||||
|
||||
|
||||
class Industry(AbstractBase):
|
||||
__tablename__ = 'industry'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
|
||||
companies = relationship('Company', secondary='company_industries', back_populates='industries')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
|
||||
class Company(AbstractBase):
|
||||
__tablename__ = 'company'
|
||||
|
||||
NAME_SORTING_PARAMS = ( # passed to `multi_max`
|
||||
lambda obj: date(1, 1, 1) if obj.date_registered is None else obj.date_registered,
|
||||
'date_updated',
|
||||
)
|
||||
# If date_registered is None, the name is considered to be older than one with a date_registered
|
||||
|
||||
id = Column(UUIDType, primary_key=True, default=uuid4)
|
||||
visible = Column(Boolean, default=True, nullable=False, index=True)
|
||||
legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True)
|
||||
insolvent = Column(Boolean, default=False, nullable=False, index=True)
|
||||
founding_date = Column(Date)
|
||||
liquidation_date = Column(Date)
|
||||
# TODO: Get rid of city; implement address properly
|
||||
city = Column(String(255), index=True)
|
||||
address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True)
|
||||
|
||||
industries = relationship('Industry', secondary='company_industries', back_populates='companies')
|
||||
names = relationship('CompanyName', backref='company', lazy='selectin')
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.current_name or f"<Company {self.id}>")
|
||||
|
||||
@property
|
||||
def current_name(self) -> Optional['CompanyName']:
|
||||
return multi_max(list(self.names), *self.NAME_SORTING_PARAMS, default=None)
|
||||
|
||||
|
||||
company_industries = Table(
|
||||
'company_industries',
|
||||
ORMBase.metadata,
|
||||
Column('company_id', FKey('company.id'), primary_key=True),
|
||||
Column('industry_id', FKey('industry.id'), primary_key=True),
|
||||
)
|
||||
|
||||
|
||||
class CompanyName(AbstractBase):
|
||||
__tablename__ = 'company_name'
|
||||
|
||||
MAX_SLUG_LENGTH = 255
|
||||
|
||||
id = Column(UUIDType, primary_key=True, default=uuid4)
|
||||
name = Column(Unicode(768), nullable=False, index=True)
|
||||
company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True)
|
||||
date_registered = Column(Date)
|
||||
slug = Column(String(MAX_SLUG_LENGTH), index=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
|
||||
@listens_for(CompanyName, 'before_insert')
|
||||
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
|
||||
if target.slug:
|
||||
return
|
||||
slug = slugify(target.name)[:(target.MAX_SLUG_LENGTH - 2)]
|
||||
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
|
||||
num = connection.execute(statement).scalar()
|
||||
if num == 0:
|
||||
target.slug = slug
|
||||
else:
|
||||
target.slug = f'{slug}-{str(num + 1)}'
|
||||
|
||||
|
||||
def get_reg_date(obj: CompanyName) -> date:
|
||||
if obj.date_registered is None:
|
||||
return date(1, 1, 1)
|
||||
return obj.date_registered
|
@ -1,56 +0,0 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql.schema import Column, ForeignKey as FKey
|
||||
from sqlalchemy.sql.sqltypes import Integer, String, Unicode
|
||||
from sqlalchemy_utils.types import CountryType, UUIDType
|
||||
|
||||
from .base import AbstractBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'StateProvince',
|
||||
'City',
|
||||
'Street',
|
||||
'Address',
|
||||
]
|
||||
|
||||
|
||||
class StateProvince(AbstractBase):
|
||||
__tablename__ = 'state_province'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
country = Column(CountryType, nullable=False, index=True)
|
||||
name = Column(Unicode(255), nullable=False, index=True)
|
||||
|
||||
cities = relationship('City', backref='state_province', lazy='selectin')
|
||||
|
||||
|
||||
class City(AbstractBase):
|
||||
__tablename__ = 'city'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
zip_code = Column(String(5), nullable=False, index=True)
|
||||
name = Column(Unicode(255), nullable=False, index=True)
|
||||
|
||||
streets = relationship('Street', backref='city', lazy='selectin')
|
||||
|
||||
|
||||
class Street(AbstractBase):
|
||||
__tablename__ = 'street'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
name = Column(Unicode(255), nullable=False, index=True)
|
||||
|
||||
addresses = relationship('Address', backref='street', lazy='selectin')
|
||||
|
||||
|
||||
class Address(AbstractBase):
|
||||
__tablename__ = 'address'
|
||||
|
||||
id = Column(UUIDType, primary_key=True, default=uuid4)
|
||||
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
|
||||
house_number = Column(String(8), nullable=False)
|
||||
supplement = Column(String(255))
|
@ -1,2 +1,3 @@
|
||||
from .base import DB
|
||||
from .geography import *
|
||||
from .companies import *
|
80
src/compub/models/base.py
Normal file
80
src/compub/models/base.py
Normal file
@ -0,0 +1,80 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterator, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.future.engine import Engine
|
||||
from sqlalchemy.orm.session import sessionmaker
|
||||
from sqlalchemy.sql.functions import now as db_now
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import TIMESTAMP
|
||||
from sqlalchemy_utils.functions.orm import get_columns
|
||||
from sqlalchemy_utils.listeners import force_auto_coercion
|
||||
from sqlmodel.main import Field, SQLModel
|
||||
|
||||
from compub.exceptions import NoDatabaseConfigured
|
||||
from compub.settings import settings
|
||||
|
||||
from sqlmodel import create_engine, Session
|
||||
|
||||
|
||||
class DB:
|
||||
def __init__(self):
|
||||
self._engine: AsyncEngine | Engine | None = None
|
||||
self._session_maker: sessionmaker | None = None
|
||||
|
||||
def start_engine(self) -> None:
|
||||
if settings.db_uri is None:
|
||||
raise NoDatabaseConfigured
|
||||
if settings.db_uri.scheme == 'sqlite':
|
||||
self._engine = create_engine(settings.db_uri)
|
||||
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
|
||||
else:
|
||||
self._engine = create_async_engine(settings.db_uri, future=True)
|
||||
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
|
||||
force_auto_coercion()
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
if self._engine is None:
|
||||
self.start_engine()
|
||||
assert isinstance(self._engine, (AsyncEngine, Engine))
|
||||
return self._engine
|
||||
|
||||
async def get_session(self) -> Session | AsyncSession:
|
||||
if self._session_maker is None:
|
||||
self.start_engine()
|
||||
assert isinstance(self._session_maker, sessionmaker)
|
||||
session = self._session_maker()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
class AbstractBase(SQLModel):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
date_created: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
|
||||
)
|
||||
date_updated: Optional[datetime] = Field(
|
||||
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
|
||||
attrs = ', '.join(f"{name}={repr(value)}" for name, value in fields)
|
||||
return f"{self.__class__.__name__}({attrs})"
|
||||
|
||||
def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
|
||||
for name in get_columns(self).keys():
|
||||
if excl_non_repr and name in self.get_non_repr_fields():
|
||||
continue
|
||||
value = getattr(self, name)
|
||||
if excl_null_value and value is None:
|
||||
continue
|
||||
yield name, value
|
||||
|
||||
@staticmethod
|
||||
def get_non_repr_fields() -> list[str]:
|
||||
return ['id', 'date_created', 'date_updated']
|
170
src/compub/models/companies.py
Normal file
170
src/compub/models/companies.py
Normal file
@ -0,0 +1,170 @@
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from slugify import slugify
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.event.api import listens_for
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.sql.expression import select
|
||||
from sqlalchemy.sql.functions import count
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import Unicode
|
||||
from sqlalchemy_utils.types import CountryType
|
||||
from sqlmodel.main import Field, Relationship
|
||||
|
||||
from compub.utils import multi_max
|
||||
from .base import AbstractBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LegalForm',
|
||||
'LegalFormSubcategory',
|
||||
'Industry',
|
||||
'Company',
|
||||
'CompanyName',
|
||||
]
|
||||
|
||||
|
||||
class LegalForm(AbstractBase, table=True):
|
||||
__tablename__ = 'legal_form'
|
||||
|
||||
# Fields
|
||||
short: str = Field(max_length=32, nullable=False, index=True)
|
||||
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
|
||||
country: str = Field(sa_column=Column(CountryType))
|
||||
|
||||
# Relationships
|
||||
subcategories: list['LegalFormSubcategory'] = Relationship(
|
||||
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.short)
|
||||
|
||||
|
||||
class LegalFormSubcategory(AbstractBase, table=True):
|
||||
__tablename__ = 'legal_form_subcategory'
|
||||
|
||||
# Fields
|
||||
short: str = Field(max_length=32, nullable=False, index=True)
|
||||
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
|
||||
|
||||
# Relationships
|
||||
legal_form_id: Optional[int] = Field(
|
||||
foreign_key='legal_form.id', default=None, nullable=False, index=True
|
||||
)
|
||||
legal_form: Optional[LegalForm] = Relationship(
|
||||
back_populates='subcategories', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
companies: list['Company'] = Relationship(
|
||||
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.short)
|
||||
|
||||
|
||||
class CompanyIndustryLink(AbstractBase, table=True):
|
||||
__tablename__ = 'company_industries'
|
||||
|
||||
# Relationships
|
||||
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
||||
industry_id: Optional[int] = Field(foreign_key='industry.id', default=None, nullable=False, primary_key=True)
|
||||
|
||||
|
||||
class Industry(AbstractBase, table=True):
|
||||
__tablename__ = 'industry'
|
||||
|
||||
# Fields
|
||||
name: str = Field(max_length=255, nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
companies: list['Company'] = Relationship(
|
||||
back_populates='industries', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
|
||||
class Company(AbstractBase, table=True):
|
||||
__tablename__ = 'company'
|
||||
|
||||
# Fields
|
||||
visible: bool = Field(default=True, nullable=False, index=True)
|
||||
insolvent: bool = Field(default=False, nullable=False, index=True)
|
||||
founding_data: Optional[date]
|
||||
liquidation_date: Optional[date]
|
||||
city: str = Field(max_length=255, index=True) # TODO: Get rid of city; implement address properly
|
||||
|
||||
# Relationships
|
||||
legal_form_id: Optional[int] = Field(
|
||||
foreign_key='legal_form_subcategory.id', default=None, index=True
|
||||
)
|
||||
legal_form: Optional[LegalFormSubcategory] = Relationship(
|
||||
back_populates='companies', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
address_id: Optional[int] = Field(
|
||||
foreign_key='address.id', default=None, index=True
|
||||
)
|
||||
|
||||
industries: list[Industry] = Relationship(
|
||||
back_populates='companies', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
names: list['CompanyName'] = Relationship(
|
||||
back_populates='company', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.current_name or f"<Company {self.id}>")
|
||||
|
||||
@property
|
||||
def current_name(self) -> Optional['CompanyName']:
|
||||
return multi_max(list(self.names), CompanyName.get_reg_date, 'date_updated', default=None)
|
||||
|
||||
|
||||
class CompanyName(AbstractBase, table=True):
|
||||
__tablename__ = 'company_name'
|
||||
__MAX_NAME_LENGTH__: int = 768
|
||||
__MAX_SLUG_LENGTH__: int = 255
|
||||
|
||||
# Fields
|
||||
name: str = Field(
|
||||
max_length=__MAX_NAME_LENGTH__, sa_column=Column(Unicode(__MAX_NAME_LENGTH__), nullable=False, index=True)
|
||||
)
|
||||
date_registered: Optional[date]
|
||||
slug: Optional[str] = Field(default=None, max_length=__MAX_SLUG_LENGTH__, index=True)
|
||||
|
||||
# Relationships
|
||||
company_id: Optional[int] = Field(
|
||||
foreign_key='company.id', default=None, nullable=False, index=True
|
||||
)
|
||||
company: Optional[Company] = Relationship(
|
||||
back_populates='names', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
def get_reg_date(self) -> date:
|
||||
return date(1, 1, 1) if self.date_registered is None else self.date_registered
|
||||
|
||||
@property
|
||||
def max_slug_length(self) -> int:
|
||||
return self.__MAX_SLUG_LENGTH__
|
||||
|
||||
|
||||
@listens_for(CompanyName, 'before_insert')
|
||||
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
|
||||
if target.slug:
|
||||
return
|
||||
slug = slugify(target.name)[:(target.max_slug_length - 2)]
|
||||
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
|
||||
num = connection.execute(statement).scalar()
|
||||
if num == 0:
|
||||
target.slug = slug
|
||||
else:
|
||||
target.slug = f'{slug}-{str(num + 1)}'
|
94
src/compub/models/geography.py
Normal file
94
src/compub/models/geography.py
Normal file
@ -0,0 +1,94 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import validator
|
||||
from sqlalchemy.sql.schema import Column
|
||||
from sqlalchemy.sql.sqltypes import Unicode
|
||||
from sqlalchemy_utils.primitives.country import Country
|
||||
from sqlalchemy_utils.types import CountryType
|
||||
from sqlmodel.main import Field, Relationship
|
||||
|
||||
from .base import AbstractBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
'StateProvince',
|
||||
'City',
|
||||
'Street',
|
||||
'Address',
|
||||
]
|
||||
|
||||
|
||||
class StateProvince(AbstractBase, table=True):
|
||||
__tablename__ = 'state_province'
|
||||
|
||||
# Fields
|
||||
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
|
||||
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||
|
||||
# Relationships
|
||||
cities: list['City'] = Relationship(
|
||||
back_populates='state_province', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
@validator('country', pre=True)
|
||||
def country_as_uppercase_string(cls, v: Any) -> str:
|
||||
if isinstance(v, Country):
|
||||
return v.code
|
||||
if isinstance(v, str):
|
||||
return v.upper()
|
||||
raise TypeError
|
||||
|
||||
|
||||
class City(AbstractBase, table=True):
|
||||
__tablename__ = 'city'
|
||||
|
||||
# Fields
|
||||
zip_code: str = Field(max_length=5, nullable=False, index=True)
|
||||
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||
|
||||
# Relationships
|
||||
state_province_id: Optional[int] = Field(
|
||||
foreign_key='state_province.id', default=None, nullable=False, index=True
|
||||
)
|
||||
state_province: Optional[StateProvince] = Relationship(
|
||||
back_populates='cities', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
streets: list['Street'] = Relationship(
|
||||
back_populates='city', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
|
||||
class Street(AbstractBase, table=True):
|
||||
__tablename__ = 'street'
|
||||
|
||||
# Fields
|
||||
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||
|
||||
# Relationships
|
||||
city_id: Optional[int] = Field(
|
||||
foreign_key='city.id', default=None, nullable=False, index=True
|
||||
)
|
||||
city: Optional[City] = Relationship(
|
||||
back_populates='streets', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
addresses: list['Address'] = Relationship(
|
||||
back_populates='street', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
||||
|
||||
|
||||
class Address(AbstractBase, table=True):
|
||||
__tablename__ = 'address'
|
||||
|
||||
# Fields
|
||||
house_number: str = Field(max_length=8, nullable=False)
|
||||
supplement: str = Field(max_length=255)
|
||||
|
||||
# Relationships
|
||||
street_id: Optional[int] = Field(
|
||||
foreign_key='street.id', default=None, nullable=False, index=True
|
||||
)
|
||||
street: Optional[Street] = Relationship(
|
||||
back_populates='addresses', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||
)
|
43
src/compub/routes.py
Normal file
43
src/compub/routes.py
Normal file
@ -0,0 +1,43 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from .models import *
|
||||
from .models.base import DB
|
||||
from . import crud
|
||||
|
||||
|
||||
api, db = FastAPI(), DB()
|
||||
|
||||
|
||||
@api.on_event('startup')
|
||||
async def initialize_db():
|
||||
async with db.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
@api.on_event("shutdown")
|
||||
async def close_connection_pool():
|
||||
await db.engine.dispose()
|
||||
|
||||
|
||||
@api.post("/states/", response_model=StateProvince)
|
||||
async def create_state(state: StateProvince, session: AsyncSession = Depends(db.get_session)):
|
||||
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
|
||||
# if db_obj:
|
||||
# raise HTTPException(status_code=400, detail="Name already registered")
|
||||
return await crud.create_state(session, state=state)
|
||||
|
||||
|
||||
@api.get("/states/", response_model=list[StateProvince])
|
||||
async def list_states(skip: int = 0, limit: int = 100, session: AsyncSession = Depends(db.get_session)):
|
||||
return await crud.get_states(session, skip=skip, limit=limit)
|
||||
|
||||
|
||||
# @app.post("/cities/", response_model=City)
|
||||
# async def create_city(city: City, session: AsyncSession = Depends(get_session)):
|
||||
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
|
||||
# if db_obj:
|
||||
# raise HTTPException(status_code=400, detail="Name already registered")
|
||||
# return await crud.create_state(session, state=state)
|
@ -3,7 +3,7 @@ import logging.config
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar
|
||||
|
||||
from pydantic import BaseModel, BaseSettings, PostgresDsn, validator
|
||||
from pydantic import BaseModel, BaseSettings, AnyUrl, validator
|
||||
from pydantic.env_settings import SettingsSourceCallable
|
||||
from yaml import safe_load
|
||||
|
||||
@ -76,8 +76,12 @@ class ServerSettings(BaseModel):
|
||||
uds: str | None = None
|
||||
|
||||
|
||||
class DBURI(AnyUrl):
|
||||
host_required = False
|
||||
|
||||
|
||||
class Settings(AbstractBaseSettings):
|
||||
db_uri: PostgresDsn | None = None
|
||||
db_uri: DBURI | None = None
|
||||
server: ServerSettings = ServerSettings()
|
||||
log_config: dict | Path | None = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user