From b417ac16e01eaf1b3d4cbd15216f72dcd60d61c8 Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg Date: Sun, 14 Aug 2022 11:44:42 +0200 Subject: [PATCH] Add indices --- src/compub/models/base.py | 26 ++++++++++++++++++-------- src/compub/models/companies.py | 17 +++++++++++++++-- src/compub/models/geography.py | 8 +++++++- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/compub/models/base.py b/src/compub/models/base.py index 4e605ee..587e3c1 100644 --- a/src/compub/models/base.py +++ b/src/compub/models/base.py @@ -21,17 +21,17 @@ from sqlmodel import create_engine, Session class DB: def __init__(self): self._engine: AsyncEngine | Engine | None = None - self._session_maker: sessionmaker | None = None + self.session_maker: sessionmaker | None = None def start_engine(self) -> None: if settings.db_uri is None: raise NoDatabaseConfigured if settings.db_uri.scheme == 'sqlite': - self._engine = create_engine(settings.db_uri) - self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session) + self._engine = create_engine(settings.db_uri, echo=True) + self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session) else: self._engine = create_async_engine(settings.db_uri, future=True) - self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) + self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) force_auto_coercion() @property @@ -41,16 +41,26 @@ class DB: assert isinstance(self._engine, (AsyncEngine, Engine)) return self._engine - async def get_session(self) -> Session | AsyncSession: - if self._session_maker is None: + async def get_session(self) -> AsyncSession: + if self.session_maker is None: self.start_engine() - assert isinstance(self._session_maker, sessionmaker) - session = self._session_maker() + assert isinstance(self.session_maker.class_, AsyncSession) + session = self.session_maker() try: yield session finally: await session.close() + def get_session_blocking(self) -> Session: + if self.session_maker is None: + self.start_engine() + assert isinstance(self.session_maker.class_, Session) + session = self.session_maker() + try: + yield session + finally: + session.close() + class AbstractBase(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) diff --git a/src/compub/models/companies.py b/src/compub/models/companies.py index fe50194..5c7b94a 100644 --- a/src/compub/models/companies.py +++ b/src/compub/models/companies.py @@ -7,7 +7,7 @@ from sqlalchemy.event.api import listens_for from sqlalchemy.orm.mapper import Mapper from sqlalchemy.sql.expression import select from sqlalchemy.sql.functions import count -from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.schema import Column, Index from sqlalchemy.sql.sqltypes import Unicode from sqlalchemy_utils.types import CountryType from sqlmodel.main import Field, Relationship @@ -45,6 +45,9 @@ class LegalForm(AbstractBase, table=True): class LegalFormSubcategory(AbstractBase, table=True): __tablename__ = 'legal_form_subcategory' + __table_args__ = ( + Index('ux_legal_form_subcategory', 'short', 'legal_form_id', unique=True), + ) # Fields short: str = Field(max_length=32, nullable=False, index=True) @@ -68,6 +71,9 @@ class LegalFormSubcategory(AbstractBase, table=True): class CompanyIndustryLink(AbstractBase, table=True): __tablename__ = 'company_industry' + __table_args__ = ( + Index('ux_company_industry', 'company_id', 'industry_id', unique=True), + ) # Relationships company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) @@ -76,6 +82,9 @@ class CompanyIndustryLink(AbstractBase, table=True): class CompanyExecutiveLink(AbstractBase, table=True): __tablename__ = 'company_executive' + __table_args__ = ( + Index('ux_company_executive', 'company_id', 'executive_id', unique=True), + ) # Relationships company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) @@ -86,7 +95,7 @@ class Industry(AbstractBase, table=True): __tablename__ = 'industry' # Fields - name: str = Field(max_length=255, nullable=False, index=True) + name: str = Field(max_length=255, nullable=False, index=True, sa_column_kwargs={'unique': True}) # Relationships companies: list['Company'] = Relationship( @@ -160,6 +169,10 @@ class Company(AbstractBase, table=True): class CompanyName(AbstractBase, table=True): __tablename__ = 'company_name' + __table_args__ = ( + Index('ux_company_name_company_id', 'name', 'company_id', unique=True), + ) + __MAX_LENGTH_NAME__: int = 768 __MAX_SLUG_LENGTH__: int = 255 diff --git a/src/compub/models/geography.py b/src/compub/models/geography.py index 9aed86a..db52591 100644 --- a/src/compub/models/geography.py +++ b/src/compub/models/geography.py @@ -1,7 +1,7 @@ from typing import Any, Optional from pydantic import validator -from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.schema import Column, Index from sqlalchemy.sql.sqltypes import Unicode from sqlalchemy_utils.primitives.country import Country from sqlalchemy_utils.types import CountryType @@ -20,6 +20,9 @@ __all__ = [ class StateProvince(AbstractBase, table=True): __tablename__ = 'state_province' + __table_args__ = ( + Index('ux_state_province_name_country', 'name', 'country', unique=True), + ) # Fields country: str = Field(sa_column=Column(CountryType, nullable=False, index=True)) @@ -41,6 +44,9 @@ class StateProvince(AbstractBase, table=True): class City(AbstractBase, table=True): __tablename__ = 'city' + __table_args__ = ( + Index('ux_city_name_zip_state', 'name', 'zip_code', 'state_province_id', unique=True), + ) # Fields zip_code: str = Field(max_length=5, nullable=False, index=True)