Add indices

This commit is contained in:
Daniil Fajnberg 2022-08-14 11:44:42 +02:00
parent 3f6be9adbc
commit b417ac16e0
3 changed files with 40 additions and 11 deletions

View File

@ -21,17 +21,17 @@ from sqlmodel import create_engine, Session
class DB: class DB:
def __init__(self): def __init__(self):
self._engine: AsyncEngine | Engine | None = None self._engine: AsyncEngine | Engine | None = None
self._session_maker: sessionmaker | None = None self.session_maker: sessionmaker | None = None
def start_engine(self) -> None: def start_engine(self) -> None:
if settings.db_uri is None: if settings.db_uri is None:
raise NoDatabaseConfigured raise NoDatabaseConfigured
if settings.db_uri.scheme == 'sqlite': if settings.db_uri.scheme == 'sqlite':
self._engine = create_engine(settings.db_uri) self._engine = create_engine(settings.db_uri, echo=True)
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session) self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
else: else:
self._engine = create_async_engine(settings.db_uri, future=True) self._engine = create_async_engine(settings.db_uri, future=True)
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession) self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
force_auto_coercion() force_auto_coercion()
@property @property
@ -41,16 +41,26 @@ class DB:
assert isinstance(self._engine, (AsyncEngine, Engine)) assert isinstance(self._engine, (AsyncEngine, Engine))
return self._engine return self._engine
async def get_session(self) -> Session | AsyncSession: async def get_session(self) -> AsyncSession:
if self._session_maker is None: if self.session_maker is None:
self.start_engine() self.start_engine()
assert isinstance(self._session_maker, sessionmaker) assert isinstance(self.session_maker.class_, AsyncSession)
session = self._session_maker() session = self.session_maker()
try: try:
yield session yield session
finally: finally:
await session.close() await session.close()
def get_session_blocking(self) -> Session:
if self.session_maker is None:
self.start_engine()
assert isinstance(self.session_maker.class_, Session)
session = self.session_maker()
try:
yield session
finally:
session.close()
class AbstractBase(SQLModel): class AbstractBase(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)

View File

@ -7,7 +7,7 @@ from sqlalchemy.event.api import listens_for
from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.sql.expression import select from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import count from sqlalchemy.sql.functions import count
from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Column, Index
from sqlalchemy.sql.sqltypes import Unicode from sqlalchemy.sql.sqltypes import Unicode
from sqlalchemy_utils.types import CountryType from sqlalchemy_utils.types import CountryType
from sqlmodel.main import Field, Relationship from sqlmodel.main import Field, Relationship
@ -45,6 +45,9 @@ class LegalForm(AbstractBase, table=True):
class LegalFormSubcategory(AbstractBase, table=True): class LegalFormSubcategory(AbstractBase, table=True):
__tablename__ = 'legal_form_subcategory' __tablename__ = 'legal_form_subcategory'
__table_args__ = (
Index('ux_legal_form_subcategory', 'short', 'legal_form_id', unique=True),
)
# Fields # Fields
short: str = Field(max_length=32, nullable=False, index=True) short: str = Field(max_length=32, nullable=False, index=True)
@ -68,6 +71,9 @@ class LegalFormSubcategory(AbstractBase, table=True):
class CompanyIndustryLink(AbstractBase, table=True): class CompanyIndustryLink(AbstractBase, table=True):
__tablename__ = 'company_industry' __tablename__ = 'company_industry'
__table_args__ = (
Index('ux_company_industry', 'company_id', 'industry_id', unique=True),
)
# Relationships # Relationships
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
@ -76,6 +82,9 @@ class CompanyIndustryLink(AbstractBase, table=True):
class CompanyExecutiveLink(AbstractBase, table=True): class CompanyExecutiveLink(AbstractBase, table=True):
__tablename__ = 'company_executive' __tablename__ = 'company_executive'
__table_args__ = (
Index('ux_company_executive', 'company_id', 'executive_id', unique=True),
)
# Relationships # Relationships
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True) company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
@ -86,7 +95,7 @@ class Industry(AbstractBase, table=True):
__tablename__ = 'industry' __tablename__ = 'industry'
# Fields # Fields
name: str = Field(max_length=255, nullable=False, index=True) name: str = Field(max_length=255, nullable=False, index=True, sa_column_kwargs={'unique': True})
# Relationships # Relationships
companies: list['Company'] = Relationship( companies: list['Company'] = Relationship(
@ -160,6 +169,10 @@ class Company(AbstractBase, table=True):
class CompanyName(AbstractBase, table=True): class CompanyName(AbstractBase, table=True):
__tablename__ = 'company_name' __tablename__ = 'company_name'
__table_args__ = (
Index('ux_company_name_company_id', 'name', 'company_id', unique=True),
)
__MAX_LENGTH_NAME__: int = 768 __MAX_LENGTH_NAME__: int = 768
__MAX_SLUG_LENGTH__: int = 255 __MAX_SLUG_LENGTH__: int = 255

View File

@ -1,7 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
from pydantic import validator from pydantic import validator
from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Column, Index
from sqlalchemy.sql.sqltypes import Unicode from sqlalchemy.sql.sqltypes import Unicode
from sqlalchemy_utils.primitives.country import Country from sqlalchemy_utils.primitives.country import Country
from sqlalchemy_utils.types import CountryType from sqlalchemy_utils.types import CountryType
@ -20,6 +20,9 @@ __all__ = [
class StateProvince(AbstractBase, table=True): class StateProvince(AbstractBase, table=True):
__tablename__ = 'state_province' __tablename__ = 'state_province'
__table_args__ = (
Index('ux_state_province_name_country', 'name', 'country', unique=True),
)
# Fields # Fields
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True)) country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
@ -41,6 +44,9 @@ class StateProvince(AbstractBase, table=True):
class City(AbstractBase, table=True): class City(AbstractBase, table=True):
__tablename__ = 'city' __tablename__ = 'city'
__table_args__ = (
Index('ux_city_name_zip_state', 'name', 'zip_code', 'state_province_id', unique=True),
)
# Fields # Fields
zip_code: str = Field(max_length=5, nullable=False, index=True) zip_code: str = Field(max_length=5, nullable=False, index=True)