generated from daniil-berg/boilerplate-py
Add indices
This commit is contained in:
parent
3f6be9adbc
commit
b417ac16e0
@ -21,17 +21,17 @@ from sqlmodel import create_engine, Session
|
|||||||
class DB:
|
class DB:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._engine: AsyncEngine | Engine | None = None
|
self._engine: AsyncEngine | Engine | None = None
|
||||||
self._session_maker: sessionmaker | None = None
|
self.session_maker: sessionmaker | None = None
|
||||||
|
|
||||||
def start_engine(self) -> None:
|
def start_engine(self) -> None:
|
||||||
if settings.db_uri is None:
|
if settings.db_uri is None:
|
||||||
raise NoDatabaseConfigured
|
raise NoDatabaseConfigured
|
||||||
if settings.db_uri.scheme == 'sqlite':
|
if settings.db_uri.scheme == 'sqlite':
|
||||||
self._engine = create_engine(settings.db_uri)
|
self._engine = create_engine(settings.db_uri, echo=True)
|
||||||
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
|
self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
|
||||||
else:
|
else:
|
||||||
self._engine = create_async_engine(settings.db_uri, future=True)
|
self._engine = create_async_engine(settings.db_uri, future=True)
|
||||||
self._session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
|
self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
|
||||||
force_auto_coercion()
|
force_auto_coercion()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -41,16 +41,26 @@ class DB:
|
|||||||
assert isinstance(self._engine, (AsyncEngine, Engine))
|
assert isinstance(self._engine, (AsyncEngine, Engine))
|
||||||
return self._engine
|
return self._engine
|
||||||
|
|
||||||
async def get_session(self) -> Session | AsyncSession:
|
async def get_session(self) -> AsyncSession:
|
||||||
if self._session_maker is None:
|
if self.session_maker is None:
|
||||||
self.start_engine()
|
self.start_engine()
|
||||||
assert isinstance(self._session_maker, sessionmaker)
|
assert isinstance(self.session_maker.class_, AsyncSession)
|
||||||
session = self._session_maker()
|
session = self.session_maker()
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
finally:
|
finally:
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
def get_session_blocking(self) -> Session:
|
||||||
|
if self.session_maker is None:
|
||||||
|
self.start_engine()
|
||||||
|
assert isinstance(self.session_maker.class_, Session)
|
||||||
|
session = self.session_maker()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
class AbstractBase(SQLModel):
|
class AbstractBase(SQLModel):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
@ -7,7 +7,7 @@ from sqlalchemy.event.api import listens_for
|
|||||||
from sqlalchemy.orm.mapper import Mapper
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
from sqlalchemy.sql.expression import select
|
from sqlalchemy.sql.expression import select
|
||||||
from sqlalchemy.sql.functions import count
|
from sqlalchemy.sql.functions import count
|
||||||
from sqlalchemy.sql.schema import Column
|
from sqlalchemy.sql.schema import Column, Index
|
||||||
from sqlalchemy.sql.sqltypes import Unicode
|
from sqlalchemy.sql.sqltypes import Unicode
|
||||||
from sqlalchemy_utils.types import CountryType
|
from sqlalchemy_utils.types import CountryType
|
||||||
from sqlmodel.main import Field, Relationship
|
from sqlmodel.main import Field, Relationship
|
||||||
@ -45,6 +45,9 @@ class LegalForm(AbstractBase, table=True):
|
|||||||
|
|
||||||
class LegalFormSubcategory(AbstractBase, table=True):
|
class LegalFormSubcategory(AbstractBase, table=True):
|
||||||
__tablename__ = 'legal_form_subcategory'
|
__tablename__ = 'legal_form_subcategory'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ux_legal_form_subcategory', 'short', 'legal_form_id', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
short: str = Field(max_length=32, nullable=False, index=True)
|
short: str = Field(max_length=32, nullable=False, index=True)
|
||||||
@ -68,6 +71,9 @@ class LegalFormSubcategory(AbstractBase, table=True):
|
|||||||
|
|
||||||
class CompanyIndustryLink(AbstractBase, table=True):
|
class CompanyIndustryLink(AbstractBase, table=True):
|
||||||
__tablename__ = 'company_industry'
|
__tablename__ = 'company_industry'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ux_company_industry', 'company_id', 'industry_id', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
||||||
@ -76,6 +82,9 @@ class CompanyIndustryLink(AbstractBase, table=True):
|
|||||||
|
|
||||||
class CompanyExecutiveLink(AbstractBase, table=True):
|
class CompanyExecutiveLink(AbstractBase, table=True):
|
||||||
__tablename__ = 'company_executive'
|
__tablename__ = 'company_executive'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ux_company_executive', 'company_id', 'executive_id', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
company_id: Optional[int] = Field(foreign_key='company.id', default=None, nullable=False, primary_key=True)
|
||||||
@ -86,7 +95,7 @@ class Industry(AbstractBase, table=True):
|
|||||||
__tablename__ = 'industry'
|
__tablename__ = 'industry'
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
name: str = Field(max_length=255, nullable=False, index=True)
|
name: str = Field(max_length=255, nullable=False, index=True, sa_column_kwargs={'unique': True})
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
companies: list['Company'] = Relationship(
|
companies: list['Company'] = Relationship(
|
||||||
@ -160,6 +169,10 @@ class Company(AbstractBase, table=True):
|
|||||||
|
|
||||||
class CompanyName(AbstractBase, table=True):
|
class CompanyName(AbstractBase, table=True):
|
||||||
__tablename__ = 'company_name'
|
__tablename__ = 'company_name'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ux_company_name_company_id', 'name', 'company_id', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
__MAX_LENGTH_NAME__: int = 768
|
__MAX_LENGTH_NAME__: int = 768
|
||||||
__MAX_SLUG_LENGTH__: int = 255
|
__MAX_SLUG_LENGTH__: int = 255
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
from sqlalchemy.sql.schema import Column
|
from sqlalchemy.sql.schema import Column, Index
|
||||||
from sqlalchemy.sql.sqltypes import Unicode
|
from sqlalchemy.sql.sqltypes import Unicode
|
||||||
from sqlalchemy_utils.primitives.country import Country
|
from sqlalchemy_utils.primitives.country import Country
|
||||||
from sqlalchemy_utils.types import CountryType
|
from sqlalchemy_utils.types import CountryType
|
||||||
@ -20,6 +20,9 @@ __all__ = [
|
|||||||
|
|
||||||
class StateProvince(AbstractBase, table=True):
|
class StateProvince(AbstractBase, table=True):
|
||||||
__tablename__ = 'state_province'
|
__tablename__ = 'state_province'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ux_state_province_name_country', 'name', 'country', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
|
country: str = Field(sa_column=Column(CountryType, nullable=False, index=True))
|
||||||
@ -41,6 +44,9 @@ class StateProvince(AbstractBase, table=True):
|
|||||||
|
|
||||||
class City(AbstractBase, table=True):
|
class City(AbstractBase, table=True):
|
||||||
__tablename__ = 'city'
|
__tablename__ = 'city'
|
||||||
|
__table_args__ = (
|
||||||
|
Index('ux_city_name_zip_state', 'name', 'zip_code', 'state_province_id', unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
zip_code: str = Field(max_length=5, nullable=False, index=True)
|
zip_code: str = Field(max_length=5, nullable=False, index=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user