# compub/src/compub/models/base.py
#
# 99 lines
# 3.5 KiB
# Python

from datetime import datetime
from typing import Any, AsyncIterator, Iterator, Optional

from sqlalchemy.ext.asyncio.engine import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.future.engine import Engine
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql.functions import now as db_now
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import TIMESTAMP
from sqlalchemy_utils.functions.orm import get_columns
from sqlalchemy_utils.listeners import force_auto_coercion
from sqlalchemy_utils.types.choice import Choice
from sqlmodel import Session, create_engine
from sqlmodel.main import Field, SQLModel

from compub.exceptions import NoDatabaseConfigured
from compub.settings import settings

# Python type used for the ``id`` primary-key field declared on AbstractBase.
DEFAULT_PK_TYPE: type = int


class DB:
    """Lazily initialised holder for the SQLAlchemy engine and its session factory.

    The engine is built from ``settings.db_uri`` on first use. A ``sqlite`` URI gets a
    synchronous engine with a SQLModel :class:`Session` factory; any other scheme gets
    an async engine with an :class:`AsyncSession` factory.
    """

    def __init__(self):
        # Both stay None until start_engine() runs (triggered lazily by the accessors below).
        self._engine: AsyncEngine | Engine | None = None
        self.session_maker: sessionmaker | None = None

    def start_engine(self) -> None:
        """Create the engine and the matching session factory from ``settings.db_uri``.

        Raises:
            NoDatabaseConfigured: if ``settings.db_uri`` is ``None``.
        """
        if settings.db_uri is None:
            raise NoDatabaseConfigured
        if settings.db_uri.scheme == 'sqlite':
            # NOTE(review): echo=True logs every SQL statement; confirm this is intended
            # outside of development.
            self._engine = create_engine(settings.db_uri, echo=True)
            session_class = Session
        else:
            self._engine = create_async_engine(settings.db_uri, future=True)
            session_class = AsyncSession
        # expire_on_commit=False keeps loaded attributes readable after commit without
        # forcing a refresh from the database.
        self.session_maker = sessionmaker(self.engine, expire_on_commit=False, class_=session_class)
        # Enables sqlalchemy_utils' auto-coercion listeners; runs on every start_engine() call.
        force_auto_coercion()

    @property
    def engine(self) -> AsyncEngine | Engine:
        """Return the engine, creating it via :meth:`start_engine` on first access."""
        if self._engine is None:
            self.start_engine()
        assert isinstance(self._engine, (AsyncEngine, Engine))
        return self._engine

    async def get_session(self) -> AsyncIterator[AsyncSession]:
        """Async generator yielding one :class:`AsyncSession`, closed when the generator finishes.

        Raises:
            NoDatabaseConfigured: if no database URI is configured.
            AssertionError: if the configured backend is synchronous (``sqlite``);
                use :meth:`get_session_blocking` in that case.
        """
        if self.session_maker is None:
            self.start_engine()
        # ``class_`` is a class (sessionmaker stores a subclass of the one it was given),
        # so the check must be issubclass -- isinstance on a class is always False.
        assert issubclass(self.session_maker.class_, AsyncSession), (
            'get_session requires an async engine; use get_session_blocking instead'
        )
        session = self.session_maker()
        try:
            yield session
        finally:
            await session.close()

    def get_session_blocking(self) -> Iterator[Session]:
        """Generator yielding one synchronous :class:`Session`, closed when the generator finishes.

        Raises:
            NoDatabaseConfigured: if no database URI is configured.
            AssertionError: if the configured backend is asynchronous;
                use :meth:`get_session` in that case.
        """
        if self.session_maker is None:
            self.start_engine()
        # See get_session: ``class_`` is a class, hence issubclass.
        assert issubclass(self.session_maker.class_, Session), (
            'get_session_blocking requires a synchronous engine; use get_session instead'
        )
        session = self.session_maker()
        try:
            yield session
        finally:
            session.close()


class AbstractBase(SQLModel):
    """Common base for models: a ``DEFAULT_PK_TYPE`` primary key plus DB-managed timestamps.

    ``date_created`` and ``date_updated`` default to the database's ``now()`` on INSERT.
    ``date_updated`` also carries ``onupdate=now()``, so SQLAlchemy sets it whenever it
    emits an UPDATE for the row.
    """

    id: Optional[DEFAULT_PK_TYPE] = Field(default=None, primary_key=True)
    # NOTE(review): each Column below is a single instance shared by every table that
    # inherits this base; confirm the installed SQLModel copies it per table (otherwise
    # sa_type / sa_column_kwargs avoid "Column already assigned to Table" errors).
    date_created: Optional[datetime] = Field(
        default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now())
    )
    date_updated: Optional[datetime] = Field(
        default=None, sa_column=Column(TIMESTAMP(timezone=False), server_default=db_now(), onupdate=db_now())
    )

    def __repr__(self) -> str:
        """Render as ``ClassName(col=value, ...)``, omitting bookkeeping and ``None`` columns."""
        fields = self.iter_fields(excl_non_repr=True, excl_null_value=True)
        attrs = ', '.join(f'{name}={value!r}' for name, value in fields)
        return f'{type(self).__name__}({attrs})'

    def iter_fields(self, excl_non_repr: bool = False, excl_null_value: bool = False) -> Iterator[tuple[str, Any]]:
        """Yield ``(column_name, value)`` for each mapped column of this instance.

        Args:
            excl_non_repr: skip the names returned by :meth:`get_non_repr_fields`.
            excl_null_value: skip columns whose current value is ``None``.
        """
        # Resolve the exclusion set once rather than rebuilding the list for every column;
        # calling through ``self`` still honours subclass overrides.
        skipped = set(self.get_non_repr_fields()) if excl_non_repr else set()
        for name in get_columns(self).keys():
            if name in skipped:
                continue
            value = getattr(self, name)
            if excl_null_value and value is None:
                continue
            yield name, value

    @staticmethod
    def get_non_repr_fields() -> list[str]:
        """Return column names hidden from ``__repr__`` (primary key and timestamps)."""
        return ['id', 'date_created', 'date_updated']


def get_choice_value(obj: Choice | str) -> str:
    """Unwrap a sqlalchemy_utils ``Choice`` to its ``value``; pass plain strings through unchanged."""
    if not isinstance(obj, Choice):
        return obj
    return obj.value