generated from daniil-berg/boilerplate-py
Write a few SQLModels;
set up two basic CRUD-like functions; create two test routes; make `db_uri` setting more flexible
This commit is contained in:
parent
60c6b1c4bb
commit
9f487d515d
@ -1,7 +1,8 @@
|
|||||||
Pydantic
|
Pydantic
|
||||||
FastAPI
|
FastAPI
|
||||||
SQLAlchemy[asyncio]
|
SQLAlchemy[asyncio]==1.4.35
|
||||||
Alembic
|
Alembic
|
||||||
SQLAlchemy-Utils
|
SQLAlchemy-Utils
|
||||||
|
SQLModel
|
||||||
Babel
|
Babel
|
||||||
python-slugify
|
Python-Slugify
|
@ -21,15 +21,16 @@ classifiers =
|
|||||||
package_dir =
|
package_dir =
|
||||||
= src
|
= src
|
||||||
packages = find:
|
packages = find:
|
||||||
python_requires = >=3.10
|
python_requires = >=3.10, <4
|
||||||
install_requires =
|
install_requires =
|
||||||
Pydantic
|
Pydantic
|
||||||
FastAPI
|
FastAPI
|
||||||
SQLAlchemy[asyncio]
|
SQLAlchemy[asyncio]==1.4.35
|
||||||
Alembic
|
Alembic
|
||||||
SQLAlchemy-Utils
|
SQLAlchemy-Utils
|
||||||
|
SQLModel
|
||||||
Babel
|
Babel
|
||||||
python-slugify
|
Python-Slugify
|
||||||
|
|
||||||
[options.extras_require]
|
[options.extras_require]
|
||||||
srv =
|
srv =
|
||||||
|
1
src/compub/crud/__init__.py
Normal file
1
src/compub/crud/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .geography import *
|
33
src/compub/crud/geography.py
Normal file
33
src/compub/crud/geography.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.sql.expression import select
|
||||||
|
|
||||||
|
from compub.models.geography import *
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 'get_state_by_name_and_country',
|
||||||
|
'get_states',
|
||||||
|
'create_state'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# async def get_state_by_name_and_country(session: AsyncSession, name: str, country: str) -> db.StateProvince:
|
||||||
|
# statement = select(db.StateProvince).filter(
|
||||||
|
# db.StateProvince.name == name,
|
||||||
|
# db.StateProvince.country == country
|
||||||
|
# ).limit(1)
|
||||||
|
# result = await session.execute(statement)
|
||||||
|
# return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_states(session: AsyncSession, skip: int = 0, limit: int = 100) -> list[StateProvince]:
|
||||||
|
statement = select(StateProvince).offset(skip).limit(limit)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_state(session: AsyncSession, state: StateProvince) -> StateProvince:
|
||||||
|
session.add(state)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(state)
|
||||||
|
return state
|
@ -1,35 +0,0 @@
|
|||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
||||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
|
||||||
from sqlalchemy.sql.functions import now as db_now
|
|
||||||
from sqlalchemy.sql.schema import Column
|
|
||||||
from sqlalchemy.sql.sqltypes import TIMESTAMP
|
|
||||||
from sqlalchemy_utils.functions.orm import get_columns
|
|
||||||
from sqlalchemy_utils.listeners import force_auto_coercion
|
|
||||||
|
|
||||||
from compub.exceptions import NoDatabaseConfigured
|
|
||||||
from compub.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
if settings.db_uri is None:
|
|
||||||
raise NoDatabaseConfigured
|
|
||||||
engine = create_async_engine(settings.db_uri, future=True)
|
|
||||||
LocalAsyncSession = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
|
||||||
ORMBase = declarative_base()
|
|
||||||
|
|
||||||
force_auto_coercion()
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractBase(ORMBase):
|
|
||||||
__abstract__ = True
|
|
||||||
|
|
||||||
NON_REPR_FIELDS = ['id', 'date_created', 'date_updated']
|
|
||||||
|
|
||||||
date_created = Column(TIMESTAMP(timezone=False), server_default=db_now())
|
|
||||||
date_updated = Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
# Exclude non-representative fields:
|
|
||||||
fields = (name for name in get_columns(self).keys() if name not in self.NON_REPR_FIELDS)
|
|
||||||
# Exclude NULL value fields:
|
|
||||||
attrs = ', '.join(f"{name}={repr(getattr(self, name))}" for name in fields if getattr(self, name) is not None)
|
|
||||||
return f"{self.__class__.__name__}({attrs})"
|
|
@ -1,138 +0,0 @@
|
|||||||
from datetime import date
|
|
||||||
from typing import Optional
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from slugify import slugify
|
|
||||||
from sqlalchemy.engine import Connection
|
|
||||||
from sqlalchemy.event.api import listens_for
|
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
from sqlalchemy.orm.mapper import Mapper
|
|
||||||
from sqlalchemy.sql.expression import select
|
|
||||||
from sqlalchemy.sql.functions import count
|
|
||||||
from sqlalchemy.sql.schema import Column, ForeignKey as FKey, Table
|
|
||||||
from sqlalchemy.sql.sqltypes import Boolean, Date, Integer, String, Unicode
|
|
||||||
from sqlalchemy_utils.types import CountryType, UUIDType
|
|
||||||
|
|
||||||
from compub.utils import multi_max
|
|
||||||
from .base import AbstractBase, ORMBase
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'LegalForm',
|
|
||||||
'LegalFormSubcategory',
|
|
||||||
'Integer',
|
|
||||||
'Company',
|
|
||||||
'CompanyName',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class LegalForm(AbstractBase):
|
|
||||||
__tablename__ = 'legal_form'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
short = Column(String(32), nullable=False, index=True)
|
|
||||||
name = Column(Unicode(255))
|
|
||||||
country = Column(CountryType)
|
|
||||||
|
|
||||||
subcategories = relationship('LegalFormSubcategory', backref='legal_form', lazy='selectin')
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self.short)
|
|
||||||
|
|
||||||
|
|
||||||
class LegalFormSubcategory(AbstractBase):
|
|
||||||
__tablename__ = 'legal_form_subcategory'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
legal_form_id = Column(Integer, FKey('legal_form.id', ondelete='RESTRICT'), nullable=False, index=True)
|
|
||||||
short = Column(String(32), nullable=False, index=True)
|
|
||||||
name = Column(Unicode(255))
|
|
||||||
|
|
||||||
companies = relationship('Company', backref='legal_form', lazy='selectin')
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self.short)
|
|
||||||
|
|
||||||
|
|
||||||
class Industry(AbstractBase):
|
|
||||||
__tablename__ = 'industry'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
name = Column(String(255), nullable=False, index=True)
|
|
||||||
|
|
||||||
companies = relationship('Company', secondary='company_industries', back_populates='industries')
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self.name)
|
|
||||||
|
|
||||||
|
|
||||||
class Company(AbstractBase):
|
|
||||||
__tablename__ = 'company'
|
|
||||||
|
|
||||||
NAME_SORTING_PARAMS = ( # passed to `multi_max`
|
|
||||||
lambda obj: date(1, 1, 1) if obj.date_registered is None else obj.date_registered,
|
|
||||||
'date_updated',
|
|
||||||
)
|
|
||||||
# If date_registered is None, the name is considered to be older than one with a date_registered
|
|
||||||
|
|
||||||
id = Column(UUIDType, primary_key=True, default=uuid4)
|
|
||||||
visible = Column(Boolean, default=True, nullable=False, index=True)
|
|
||||||
legal_form_id = Column(Integer, FKey('legal_form_subcategory.id', ondelete='RESTRICT'), index=True)
|
|
||||||
insolvent = Column(Boolean, default=False, nullable=False, index=True)
|
|
||||||
founding_date = Column(Date)
|
|
||||||
liquidation_date = Column(Date)
|
|
||||||
# TODO: Get rid of city; implement address properly
|
|
||||||
city = Column(String(255), index=True)
|
|
||||||
address_id = Column(UUIDType, FKey('address.id', ondelete='RESTRICT'), index=True)
|
|
||||||
|
|
||||||
industries = relationship('Industry', secondary='company_industries', back_populates='companies')
|
|
||||||
names = relationship('CompanyName', backref='company', lazy='selectin')
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self.current_name or f"<Company {self.id}>")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_name(self) -> Optional['CompanyName']:
|
|
||||||
return multi_max(list(self.names), *self.NAME_SORTING_PARAMS, default=None)
|
|
||||||
|
|
||||||
|
|
||||||
company_industries = Table(
|
|
||||||
'company_industries',
|
|
||||||
ORMBase.metadata,
|
|
||||||
Column('company_id', FKey('company.id'), primary_key=True),
|
|
||||||
Column('industry_id', FKey('industry.id'), primary_key=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyName(AbstractBase):
|
|
||||||
__tablename__ = 'company_name'
|
|
||||||
|
|
||||||
MAX_SLUG_LENGTH = 255
|
|
||||||
|
|
||||||
id = Column(UUIDType, primary_key=True, default=uuid4)
|
|
||||||
name = Column(Unicode(768), nullable=False, index=True)
|
|
||||||
company_id = Column(UUIDType, FKey('company.id', ondelete='RESTRICT'), index=True)
|
|
||||||
date_registered = Column(Date)
|
|
||||||
slug = Column(String(MAX_SLUG_LENGTH), index=True)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self.name)
|
|
||||||
|
|
||||||
|
|
||||||
@listens_for(CompanyName, 'before_insert')
|
|
||||||
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
|
|
||||||
if target.slug:
|
|
||||||
return
|
|
||||||
slug = slugify(target.name)[:(target.MAX_SLUG_LENGTH - 2)]
|
|
||||||
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
|
|
||||||
num = connection.execute(statement).scalar()
|
|
||||||
if num == 0:
|
|
||||||
target.slug = slug
|
|
||||||
else:
|
|
||||||
target.slug = f'{slug}-{str(num + 1)}'
|
|
||||||
|
|
||||||
|
|
||||||
def get_reg_date(obj: CompanyName) -> date:
|
|
||||||
if obj.date_registered is None:
|
|
||||||
return date(1, 1, 1)
|
|
||||||
return obj.date_registered
|
|
@ -1,56 +0,0 @@
|
|||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
from sqlalchemy.sql.schema import Column, ForeignKey as FKey
|
|
||||||
from sqlalchemy.sql.sqltypes import Integer, String, Unicode
|
|
||||||
from sqlalchemy_utils.types import CountryType, UUIDType
|
|
||||||
|
|
||||||
from .base import AbstractBase
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'StateProvince',
|
|
||||||
'City',
|
|
||||||
'Street',
|
|
||||||
'Address',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class StateProvince(AbstractBase):
|
|
||||||
__tablename__ = 'state_province'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
country = Column(CountryType, nullable=False, index=True)
|
|
||||||
name = Column(Unicode(255), nullable=False, index=True)
|
|
||||||
|
|
||||||
cities = relationship('City', backref='state_province', lazy='selectin')
|
|
||||||
|
|
||||||
|
|
||||||
class City(AbstractBase):
|
|
||||||
__tablename__ = 'city'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
state_province_id = Column(Integer, FKey('state_province.id', ondelete='RESTRICT'), nullable=False, index=True)
|
|
||||||
zip_code = Column(String(5), nullable=False, index=True)
|
|
||||||
name = Column(Unicode(255), nullable=False, index=True)
|
|
||||||
|
|
||||||
streets = relationship('Street', backref='city', lazy='selectin')
|
|
||||||
|
|
||||||
|
|
||||||
class Street(AbstractBase):
|
|
||||||
__tablename__ = 'street'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
city_id = Column(Integer, FKey('city.id', ondelete='RESTRICT'), nullable=False, index=True)
|
|
||||||
name = Column(Unicode(255), nullable=False, index=True)
|
|
||||||
|
|
||||||
addresses = relationship('Address', backref='street', lazy='selectin')
|
|
||||||
|
|
||||||
|
|
||||||
class Address(AbstractBase):
|
|
||||||
__tablename__ = 'address'
|
|
||||||
|
|
||||||
id = Column(UUIDType, primary_key=True, default=uuid4)
|
|
||||||
street_id = Column(Integer, FKey('street.id', ondelete='RESTRICT'), nullable=False, index=True)
|
|
||||||
house_number = Column(String(8), nullable=False)
|
|
||||||
supplement = Column(String(255))
|
|
@ -1,2 +1,3 @@
|
|||||||
|
from .base import DB
|
||||||
from .geography import *
|
from .geography import *
|
||||||
from .companies import *
|
from .companies import *
|
80
src/compub/models/base.py
Normal file
80
src/compub/models/base.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Iterator, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
|
||||||
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlalchemy.future.engine import Engine
|
||||||
|
from sqlalchemy.orm.session import sessionmaker
|
||||||
|
from sqlalchemy.sql.functions import now as db_now
|
||||||
|
from sqlalchemy.sql.schema import Column
|
||||||
|
from sqlalchemy.sql.sqltypes import TIMESTAMP
|
||||||
|
from sqlalchemy_utils.functions.orm import get_columns
|
||||||
|
from sqlalchemy_utils.listeners import force_auto_coercion
|
||||||
|
from sqlmodel.main import Field, SQLModel
|
||||||
|
|
||||||
|
from compub.exceptions import NoDatabaseConfigured
|
||||||
|
from compub.settings import settings
|
||||||
|
|
||||||
|
from sqlmodel import create_engine, Session
|
||||||
|
|
||||||
|
|
||||||
|
class DB:
|
||||||
|
def __init__(self):
|
||||||
|
self._engine: AsyncEngine | Engine | None = None
|
||||||
|
self._session_maker: sessionmaker | None = None
|
||||||
|
|
||||||
|
def start_engine(self) -> None:
|
||||||
|
if settings.db_uri is None:
|
||||||
|
raise NoDatabaseConfigured
|
||||||
|
if settings.db_uri.scheme == 'sqlite':
|
||||||
|
self._engine = create_engine(settings.db_uri)
|
||||||
|
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
|
||||||
|
else:
|
||||||
|
self._engine = create_async_engine(settings.db_uri, future=True)
|
||||||
|
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
|
||||||
|
force_auto_coercion()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def engine(self):
|
||||||
|
if self._engine is None:
|
||||||
|
self.start_engine()
|
||||||
|
assert isinstance(self._engine, (AsyncEngine, Engine))
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
async def get_session(self) -> Session | AsyncSession:
|
||||||
|
if self._session_maker is None:
|
||||||
|
self.start_engine()
|
||||||
|
assert isinstance(self._session_maker, sessionmaker)
|
||||||
|
session = self._session_maker()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractBase(SQLModel):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
date_created: Optional[datetime] = Field(
|
||||||
|
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
|
||||||
|
)
|
||||||
|
date_updated: Optional[datetime] = Field(
|
||||||
|
default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
|
||||||
|
attrs = ', '.join(f"{name}={repr(value)}" for name, value in fields)
|
||||||
|
return f"{self.__class__.__name__}({attrs})"
|
||||||
|
|
||||||
|
def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
|
||||||
|
for name in get_columns(self).keys():
|
||||||
|
if excl_non_repr and name in self.get_non_repr_fields():
|
||||||
|
continue
|
||||||
|
value = getattr(self, name)
|
||||||
|
if excl_null_value and value is None:
|
||||||
|
continue
|
||||||
|
yield name, value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_non_repr_fields() -> list[str]:
|
||||||
|
return ['id', 'date_created', 'date_updated']
|
170
src/compub/models/companies.py
Normal file
170
src/compub/models/companies.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
from datetime import date
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from slugify import slugify
|
||||||
|
from sqlalchemy.engine.base import Connection
|
||||||
|
from sqlalchemy.event.api import listens_for
|
||||||
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
|
from sqlalchemy.sql.expression import select
|
||||||
|
from sqlalchemy.sql.functions import count
|
||||||
|
from sqlalchemy.sql.schema import Column
|
||||||
|
from sqlalchemy.sql.sqltypes import Unicode
|
||||||
|
from sqlalchemy_utils.types import CountryType
|
||||||
|
from sqlmodel.main import Field, Relationship
|
||||||
|
|
||||||
|
from compub.utils import multi_max
|
||||||
|
from .base import AbstractBase
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'LegalForm',
|
||||||
|
'LegalFormSubcategory',
|
||||||
|
'Industry',
|
||||||
|
'Company',
|
||||||
|
'CompanyName',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LegalForm(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'legal_form'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
short: str = Field(max_length=32, nullable=False, index=True)
|
||||||
|
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
|
||||||
|
country: str = Field(sa_column=Column(CountryType))
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
subcategories: list['LegalFormSubcategory'] = Relationship(
|
||||||
|
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.short)
|
||||||
|
|
||||||
|
|
||||||
|
class LegalFormSubcategory(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'legal_form_subcategory'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
short: str = Field(max_length=32, nullable=False, index=True)
|
||||||
|
name: Optional[str] = Field(default=None, max_length=255, sa_column=Column(Unicode(255)))
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
legal_form_id: Optional[int] = Field(
|
||||||
|
foreign_key='legal_form.id', default=None, nullable=False, index=True
|
||||||
|
)
|
||||||
|
legal_form: Optional[LegalForm] = Relationship(
|
||||||
|
back_populates='subcategories', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
companies: list['Company'] = Relationship(
|
||||||
|
back_populates='legal_form', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.short)
|
||||||
|
|
||||||
|
|
||||||
|
class CompanyIndustryLink(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'company_industries'
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
||||||
|
industry_id: Optional[int] = Field(foreign_key='industry.id', default=None, nullable=False, primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Industry(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'industry'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
name: str = Field(max_length=255, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
companies: list['Company'] = Relationship(
|
||||||
|
back_populates='industries', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class Company(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'company'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
visible: bool = Field(default=True, nullable=False, index=True)
|
||||||
|
insolvent: bool = Field(default=False, nullable=False, index=True)
|
||||||
|
founding_data: Optional[date]
|
||||||
|
liquidation_date: Optional[date]
|
||||||
|
city: str = Field(max_length=255, index=True) # TODO: Get rid of city; implement address properly
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
legal_form_id: Optional[int] = Field(
|
||||||
|
foreign_key='legal_form_subcategory.id', default=None, index=True
|
||||||
|
)
|
||||||
|
legal_form: Optional[LegalFormSubcategory] = Relationship(
|
||||||
|
back_populates='companies', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
address_id: Optional[int] = Field(
|
||||||
|
foreign_key='address.id', default=None, index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
industries: list[Industry] = Relationship(
|
||||||
|
back_populates='companies', link_model=CompanyIndustryLink, sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
names: list['CompanyName'] = Relationship(
|
||||||
|
back_populates='company', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.current_name or f"<Company {self.id}>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_name(self) -> Optional['CompanyName']:
|
||||||
|
return multi_max(list(self.names), CompanyName.get_reg_date, 'date_updated', default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class CompanyName(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'company_name'
|
||||||
|
__MAX_NAME_LENGTH__: int = 768
|
||||||
|
__MAX_SLUG_LENGTH__: int = 255
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
name: str = Field(
|
||||||
|
max_length=__MAX_NAME_LENGTH__, sa_column=Column(Unicode(__MAX_NAME_LENGTH__), nullable=False, index=True)
|
||||||
|
)
|
||||||
|
date_registered: Optional[date]
|
||||||
|
slug: Optional[str] = Field(default=None, max_length=__MAX_SLUG_LENGTH__, index=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
company_id: Optional[int] = Field(
|
||||||
|
foreign_key='company.id', default=None, nullable=False, index=True
|
||||||
|
)
|
||||||
|
company: Optional[Company] = Relationship(
|
||||||
|
back_populates='names', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self.name)
|
||||||
|
|
||||||
|
def get_reg_date(self) -> date:
|
||||||
|
return date(1, 1, 1) if self.date_registered is None else self.date_registered
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_slug_length(self) -> int:
|
||||||
|
return self.__MAX_SLUG_LENGTH__
|
||||||
|
|
||||||
|
|
||||||
|
@listens_for(CompanyName, 'before_insert')
|
||||||
|
def generate_company_name_slug(_mapper: Mapper, connection: Connection, target: CompanyName) -> None:
|
||||||
|
if target.slug:
|
||||||
|
return
|
||||||
|
slug = slugify(target.name)[:(target.max_slug_length - 2)]
|
||||||
|
statement = select(count()).select_from(CompanyName).where(CompanyName.slug.startswith(slug))
|
||||||
|
num = connection.execute(statement).scalar()
|
||||||
|
if num == 0:
|
||||||
|
target.slug = slug
|
||||||
|
else:
|
||||||
|
target.slug = f'{slug}-{str(num + 1)}'
|
94
src/compub/models/geography.py
Normal file
94
src/compub/models/geography.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import validator
|
||||||
|
from sqlalchemy.sql.schema import Column
|
||||||
|
from sqlalchemy.sql.sqltypes import Unicode
|
||||||
|
from sqlalchemy_utils.primitives.country import Country
|
||||||
|
from sqlalchemy_utils.types import CountryType
|
||||||
|
from sqlmodel.main import Field, Relationship
|
||||||
|
|
||||||
|
from .base import AbstractBase
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'StateProvince',
|
||||||
|
'City',
|
||||||
|
'Street',
|
||||||
|
'Address',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class StateProvince(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'state_province'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
|
||||||
|
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
cities: list['City'] = Relationship(
|
||||||
|
back_populates='state_province', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator('country', pre=True)
|
||||||
|
def country_as_uppercase_string(cls, v: Any) -> str:
|
||||||
|
if isinstance(v, Country):
|
||||||
|
return v.code
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.upper()
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
|
||||||
|
class City(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'city'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
zip_code: str = Field(max_length=5, nullable=False, index=True)
|
||||||
|
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
state_province_id: Optional[int] = Field(
|
||||||
|
foreign_key='state_province.id', default=None, nullable=False, index=True
|
||||||
|
)
|
||||||
|
state_province: Optional[StateProvince] = Relationship(
|
||||||
|
back_populates='cities', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
streets: list['Street'] = Relationship(
|
||||||
|
back_populates='city', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Street(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'street'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
name: str = Field(max_length=255, sa_column=Column(Unicode(255), nullable=False, index=True))
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
city_id: Optional[int] = Field(
|
||||||
|
foreign_key='city.id', default=None, nullable=False, index=True
|
||||||
|
)
|
||||||
|
city: Optional[City] = Relationship(
|
||||||
|
back_populates='streets', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
addresses: list['Address'] = Relationship(
|
||||||
|
back_populates='street', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Address(AbstractBase, table=True):
|
||||||
|
__tablename__ = 'address'
|
||||||
|
|
||||||
|
# Fields
|
||||||
|
house_number: str = Field(max_length=8, nullable=False)
|
||||||
|
supplement: str = Field(max_length=255)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
street_id: Optional[int] = Field(
|
||||||
|
foreign_key='street.id', default=None, nullable=False, index=True
|
||||||
|
)
|
||||||
|
street: Optional[Street] = Relationship(
|
||||||
|
back_populates='addresses', sa_relationship_kwargs={'lazy': 'selectin'}
|
||||||
|
)
|
43
src/compub/routes.py
Normal file
43
src/compub/routes.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
from .models import *
|
||||||
|
from .models.base import DB
|
||||||
|
from . import crud
|
||||||
|
|
||||||
|
|
||||||
|
api, db = FastAPI(), DB()
|
||||||
|
|
||||||
|
|
||||||
|
@api.on_event('startup')
|
||||||
|
async def initialize_db():
|
||||||
|
async with db.engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
|
@api.on_event("shutdown")
|
||||||
|
async def close_connection_pool():
|
||||||
|
await db.engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@api.post("/states/", response_model=StateProvince)
|
||||||
|
async def create_state(state: StateProvince, session: AsyncSession = Depends(db.get_session)):
|
||||||
|
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
|
||||||
|
# if db_obj:
|
||||||
|
# raise HTTPException(status_code=400, detail="Name already registered")
|
||||||
|
return await crud.create_state(session, state=state)
|
||||||
|
|
||||||
|
|
||||||
|
@api.get("/states/", response_model=list[StateProvince])
|
||||||
|
async def list_states(skip: int = 0, limit: int = 100, session: AsyncSession = Depends(db.get_session)):
|
||||||
|
return await crud.get_states(session, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
|
# @app.post("/cities/", response_model=City)
|
||||||
|
# async def create_city(city: City, session: AsyncSession = Depends(get_session)):
|
||||||
|
# db_obj = await crud.get_state_by_name_and_country(session, name=state.name, country=state.country)
|
||||||
|
# if db_obj:
|
||||||
|
# raise HTTPException(status_code=400, detail="Name already registered")
|
||||||
|
# return await crud.create_state(session, state=state)
|
@ -3,7 +3,7 @@ import logging.config
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, ClassVar
|
from typing import Any, Callable, ClassVar
|
||||||
|
|
||||||
from pydantic import BaseModel, BaseSettings, PostgresDsn, validator
|
from pydantic import BaseModel, BaseSettings, AnyUrl, validator
|
||||||
from pydantic.env_settings import SettingsSourceCallable
|
from pydantic.env_settings import SettingsSourceCallable
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
@ -76,8 +76,12 @@ class ServerSettings(BaseModel):
|
|||||||
uds: str | None = None
|
uds: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DBURI(AnyUrl):
|
||||||
|
host_required = False
|
||||||
|
|
||||||
|
|
||||||
class Settings(AbstractBaseSettings):
|
class Settings(AbstractBaseSettings):
|
||||||
db_uri: PostgresDsn | None = None
|
db_uri: DBURI | None = None
|
||||||
server: ServerSettings = ServerSettings()
|
server: ServerSettings = ServerSettings()
|
||||||
log_config: dict | Path | None = None
|
log_config: dict | Path | None = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user