generated from daniil-berg/boilerplate-py
99 lines
3.5 KiB
Python
99 lines
3.5 KiB
Python
from datetime import datetime
from typing import Any, AsyncIterator, Iterator, Optional

from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.future.engine import Engine
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import TIMESTAMP
from sqlalchemy_utils.functions.orm import get_columns
from sqlalchemy_utils.listeners import force_auto_coercion
from sqlalchemy_utils.types.choice import Choice
from sqlmodel import create_engine, Session
from sqlmodel.main import Field, SQLModel

from compub.exceptions import NoDatabaseConfigured
from compub.settings import settings
|
|
|
|
|
|
# Python type of the primary-key field (`id`) used by the model base class below.
DEFAULT_PK_TYPE: type = int
|
|
|
|
|
|
class DB:
    """Lazily created database engine plus a matching session factory.

    Nothing is connected at construction time; the engine and session maker
    are built from ``settings.db_uri`` the first time they are needed.
    SQLite URIs get a synchronous engine with ``Session``; any other scheme
    gets an async engine with ``AsyncSession``.
    """

    def __init__(self):
        self._engine: AsyncEngine | Engine | None = None
        self.session_maker: sessionmaker | None = None

    def start_engine(self) -> None:
        """Create the engine and session maker from ``settings.db_uri``.

        Raises:
            NoDatabaseConfigured: If no database URI is configured.
        """
        if settings.db_uri is None:
            raise NoDatabaseConfigured
        if settings.db_uri.scheme == 'sqlite':
            self._engine = create_engine(settings.db_uri, echo=True)
            self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=Session)
        else:
            self._engine = create_async_engine(settings.db_uri, future=True)
            self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
        force_auto_coercion()

    @property
    def engine(self) -> AsyncEngine | Engine:
        """Return the engine, starting it on first access.

        Raises:
            NoDatabaseConfigured: If no database URI is configured.
        """
        if self._engine is None:
            self.start_engine()
        assert isinstance(self._engine, (AsyncEngine, Engine))
        return self._engine

    async def get_session(self) -> AsyncIterator[AsyncSession]:
        """Async generator yielding one ``AsyncSession``, closed afterwards.

        Only valid when the configured database uses the async engine, i.e.
        a non-SQLite URI.
        """
        if self.session_maker is None:
            self.start_engine()
        # `sessionmaker.class_` holds a *class* (a subclass of the `class_`
        # argument), so it has to be checked with issubclass; an isinstance
        # check on it can never succeed.
        assert issubclass(self.session_maker.class_, AsyncSession), "session maker is not async"
        session = self.session_maker()
        try:
            yield session
        finally:
            await session.close()

    def get_session_blocking(self) -> Iterator[Session]:
        """Generator yielding one synchronous ``Session``, closed afterwards.

        Only valid when the configured database uses the synchronous engine,
        i.e. a SQLite URI.
        """
        if self.session_maker is None:
            self.start_engine()
        # See `get_session`: `class_` is a class, hence issubclass.
        assert issubclass(self.session_maker.class_, Session), "session maker is not synchronous"
        session = self.session_maker()
        try:
            yield session
        finally:
            session.close()
|
|
|
|
|
|
class AbstractBase(SQLModel):
    """Common base for models with a primary key and database-side timestamps.

    ``date_created`` defaults to the database's NOW() on insert. ``date_updated``
    has the same server default and is additionally set to NOW() on UPDATE
    statements issued through SQLAlchemy (``onupdate``).
    """

    id: Optional[DEFAULT_PK_TYPE] = Field(default=None, primary_key=True)
    # NOTE(review): these Column objects are created once on the base class;
    # confirm SQLModel gives each table=True subclass its own copy, otherwise
    # several table subclasses may clash on "Column already assigned to Table".
    date_created: Optional[datetime] = Field(
        default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
    )
    date_updated: Optional[datetime] = Field(
        default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
    )

    def __repr__(self) -> str:
        """Show the class name and every non-null, repr-relevant column value."""
        fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
        attrs = ', '.join(f"{name}={value!r}" for name, value in fields)
        return f"{type(self).__name__}({attrs})"

    def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
        """Yield ``(column_name, value)`` pairs for the mapped columns of this instance.

        Args:
            excl_non_repr: Skip the columns named by ``get_non_repr_fields()``.
            excl_null_value: Skip columns whose current value is ``None``.
        """
        # Compute the exclusion set once, instead of rebuilding the list for every column.
        excluded = set(self.get_non_repr_fields()) if excl_non_repr else set()
        for name in get_columns(self).keys():
            if name in excluded:
                continue
            value = getattr(self, name)
            if excl_null_value and value is None:
                continue
            yield name, value

    @staticmethod
    def get_non_repr_fields() -> list[str]:
        """Return the column names that ``__repr__`` leaves out."""
        return ['id', 'date_created', 'date_updated']
|
|
|
|
|
|
def get_choice_value(obj: Choice | str) -> str:
    """Return the raw value behind a ``Choice``, or the argument itself if it is a plain string.

    Args:
        obj: A ``sqlalchemy_utils`` ``Choice`` instance or an already-plain value.

    Returns:
        ``obj.value`` for a ``Choice``; otherwise ``obj`` unchanged.
    """
    if not isinstance(obj, Choice):
        return obj
    return obj.value
|