generated from daniil-berg/boilerplate-py
	Add indices
This commit is contained in:
		| @@ -21,17 +21,17 @@ from sqlmodel import create_engine, Session | ||||
| class DB: | ||||
|     def __init__(self): | ||||
|         self._engine: AsyncEngine | Engine | None = None | ||||
|         self._session_maker: sessionmaker | None = None | ||||
|         self.session_maker: sessionmaker | None = None | ||||
|  | ||||
|     def start_engine(self) -> None: | ||||
|         if settings.db_uri is None: | ||||
|             raise NoDatabaseConfigured | ||||
|         if settings.db_uri.scheme == 'sqlite': | ||||
|             self._engine = create_engine(settings.db_uri) | ||||
|             self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session) | ||||
|             self._engine = create_engine(settings.db_uri, echo=True) | ||||
|             self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session) | ||||
|         else: | ||||
|             self._engine = create_async_engine(settings.db_uri, future=True) | ||||
|             self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) | ||||
|             self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) | ||||
|         force_auto_coercion() | ||||
|  | ||||
|     @property | ||||
| @@ -41,16 +41,26 @@ class DB: | ||||
|         assert isinstance(self._engine, (AsyncEngine, Engine)) | ||||
|         return self._engine | ||||
|  | ||||
|     async def get_session(self) -> Session | AsyncSession: | ||||
|         if self._session_maker is None: | ||||
|     async def get_session(self) -> AsyncSession: | ||||
|         if self.session_maker is None: | ||||
|             self.start_engine() | ||||
|         assert isinstance(self._session_maker, sessionmaker) | ||||
|         session = self._session_maker() | ||||
|         assert isinstance(self.session_maker.class_, AsyncSession) | ||||
|         session = self.session_maker() | ||||
|         try: | ||||
|             yield session | ||||
|         finally: | ||||
|             await session.close() | ||||
|  | ||||
|     def get_session_blocking(self) -> Session: | ||||
|         if self.session_maker is None: | ||||
|             self.start_engine() | ||||
|         assert isinstance(self.session_maker.class_, Session) | ||||
|         session = self.session_maker() | ||||
|         try: | ||||
|             yield session | ||||
|         finally: | ||||
|             session.close() | ||||
|  | ||||
|  | ||||
| class AbstractBase(SQLModel): | ||||
|     id: Optional[int] = Field(default=None, primary_key=True) | ||||
|   | ||||
| @@ -7,7 +7,7 @@ from sqlalchemy.event.api import listens_for | ||||
| from sqlalchemy.orm.mapper import Mapper | ||||
| from sqlalchemy.sql.expression import select | ||||
| from sqlalchemy.sql.functions import count | ||||
| from sqlalchemy.sql.schema import Column | ||||
| from sqlalchemy.sql.schema import Column, Index | ||||
| from sqlalchemy.sql.sqltypes import Unicode | ||||
| from sqlalchemy_utils.types import CountryType | ||||
| from sqlmodel.main import Field, Relationship | ||||
| @@ -45,6 +45,9 @@ class LegalForm(AbstractBase, table=True): | ||||
|  | ||||
| class LegalFormSubcategory(AbstractBase, table=True): | ||||
|     __tablename__ = 'legal_form_subcategory' | ||||
|     __table_args__ = ( | ||||
|         Index('ux_legal_form_subcategory', 'short', 'legal_form_id', unique=True), | ||||
|     ) | ||||
|  | ||||
|     # Fields | ||||
|     short: str = Field(max_length=32, nullable=False, index=True) | ||||
| @@ -68,6 +71,9 @@ class LegalFormSubcategory(AbstractBase, table=True): | ||||
|  | ||||
| class CompanyIndustryLink(AbstractBase, table=True): | ||||
|     __tablename__ = 'company_industry' | ||||
|     __table_args__ = ( | ||||
|         Index('ux_company_industry', 'company_id', 'industry_id', unique=True), | ||||
|     ) | ||||
|  | ||||
|     # Relationships | ||||
|     company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) | ||||
| @@ -76,6 +82,9 @@ class CompanyIndustryLink(AbstractBase, table=True): | ||||
|  | ||||
| class CompanyExecutiveLink(AbstractBase, table=True): | ||||
|     __tablename__ = 'company_executive' | ||||
|     __table_args__ = ( | ||||
|         Index('ux_company_executive', 'company_id', 'executive_id', unique=True), | ||||
|     ) | ||||
|  | ||||
|     # Relationships | ||||
|     company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) | ||||
| @@ -86,7 +95,7 @@ class Industry(AbstractBase, table=True): | ||||
|     __tablename__ = 'industry' | ||||
|  | ||||
|     # Fields | ||||
|     name: str = Field(max_length=255, nullable=False, index=True) | ||||
|     name: str = Field(max_length=255, nullable=False, index=True, sa_column_kwargs={'unique': True}) | ||||
|  | ||||
|     # Relationships | ||||
|     companies: list['Company'] = Relationship( | ||||
| @@ -160,6 +169,10 @@ class Company(AbstractBase, table=True): | ||||
|  | ||||
| class CompanyName(AbstractBase, table=True): | ||||
|     __tablename__ = 'company_name' | ||||
|     __table_args__ = ( | ||||
|         Index('ux_company_name_company_id', 'name', 'company_id', unique=True), | ||||
|     ) | ||||
|  | ||||
|     __MAX_LENGTH_NAME__: int = 768 | ||||
|     __MAX_SLUG_LENGTH__: int = 255 | ||||
|  | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| from typing import Any, Optional | ||||
|  | ||||
| from pydantic import validator | ||||
| from sqlalchemy.sql.schema import Column | ||||
| from sqlalchemy.sql.schema import Column, Index | ||||
| from sqlalchemy.sql.sqltypes import Unicode | ||||
| from sqlalchemy_utils.primitives.country import Country | ||||
| from sqlalchemy_utils.types import CountryType | ||||
| @@ -20,6 +20,9 @@ __all__ = [ | ||||
|  | ||||
| class StateProvince(AbstractBase, table=True): | ||||
|     __tablename__ = 'state_province' | ||||
|     __table_args__ = ( | ||||
|         Index('ux_state_province_name_country', 'name', 'country', unique=True), | ||||
|     ) | ||||
|  | ||||
|     # Fields | ||||
|     country: str = Field(sa_column=Column(CountryType, nullable=False, index=True)) | ||||
| @@ -41,6 +44,9 @@ class StateProvince(AbstractBase, table=True): | ||||
|  | ||||
| class City(AbstractBase, table=True): | ||||
|     __tablename__ = 'city' | ||||
|     __table_args__ = ( | ||||
|         Index('ux_city_name_zip_state', 'name', 'zip_code', 'state_province_id', unique=True), | ||||
|     ) | ||||
|  | ||||
|     # Fields | ||||
|     zip_code: str = Field(max_length=5, nullable=False, index=True) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user